diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index 05b588051a1c..6d4fe0d81260 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -67,22 +67,39 @@ def _func_wrapper(expr): _register_external_op_helper("nn.batch_norm") +_register_external_op_helper("nn.conv1d") _register_external_op_helper("nn.conv2d") +_register_external_op_helper("nn.conv3d") +_register_external_op_helper("nn.conv2d_transpose") +_register_external_op_helper("nn.conv3d_transpose") _register_external_op_helper("nn.dense") +_register_external_op_helper("nn.max_pool2d") +_register_external_op_helper("nn.avg_pool2d") +_register_external_op_helper("nn.max_pool3d") +_register_external_op_helper("nn.avg_pool3d") +_register_external_op_helper("abs") +_register_external_op_helper("clip") +_register_external_op_helper("exp") +_register_external_op_helper("log") +_register_external_op_helper("sqrt") +_register_external_op_helper("round") +_register_external_op_helper("logsumexp") _register_external_op_helper("nn.relu") +_register_external_op_helper("nn.leaky_relu") _register_external_op_helper("tanh") _register_external_op_helper("sigmoid") +_register_external_op_helper("nn.softmax") _register_external_op_helper("add") _register_external_op_helper("multiply") -def make_conv_pattern(with_bias=True, with_eltwise=None): - """Create patterns related to nn.conv2d. +def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None): + """Create patterns related to conv and deconv. Parameters ---------- with_bias : bool - Whether attach `bias_add` to `nn.conv2d`. + Whether attach `bias_add` to `conv / deconv`. with_eltwise : str The attached elementwise post-op name. Returns @@ -93,7 +110,7 @@ def make_conv_pattern(with_bias=True, with_eltwise=None): data = wildcard() weight = wildcard() bias = wildcard() - conv = is_op("nn.conv2d")(data, weight) + conv = is_op(conv_name)(data, weight) if with_bias: conv_out = is_op("add")(conv, bias) else: @@ -146,15 +163,19 @@ def make_dnnl_pattern(op, with_bias, with_eltwise): pattern : Tuple(pattern_name, CallPattern) Created pattern name, along with its CallPattern. """ - pat_name = "dnnl." 
+ op
+    pat_name = op.replace("nn", "dnnl")
     pat_name += "_bias" if with_bias else ""
     pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else ""
-    if op == "conv2d":
-        dnnl_pattern = (pat_name, make_conv_pattern(with_bias, with_eltwise))
-    elif op == "dense":
+    if "conv" in op:
+        dnnl_pattern = (pat_name, make_conv_pattern(op, with_bias, with_eltwise))
+    elif op == "nn.dense":
         dnnl_pattern = (pat_name, make_dense_pattern(with_bias, with_eltwise))
     else:
-        logger.warning("Currently, only conv2d and dense op are supported, but got %s.", op)
+        logger.warning(
+            "Currently, only conv1d, conv2d, conv3d, conv2d_transpose, conv3d_transpose and "
+            "dense ops are supported, but got %s.",
+            op,
+        )
         dnnl_pattern = ()
     return dnnl_pattern
@@ -174,8 +195,15 @@ def pattern_table():
         for elt in elt_list:
             if not with_bias and not elt:
                 return dnnl_patterns
-            dnnl_patterns.append(make_dnnl_pattern("conv2d", with_bias, elt))
-            dnnl_patterns.append(make_dnnl_pattern("dense", with_bias, elt))
+            for conv_name in [
+                "nn.conv1d",
+                "nn.conv2d",
+                "nn.conv3d",
+                "nn.conv2d_transpose",
+                "nn.conv3d_transpose",
+            ]:
+                dnnl_patterns.append(make_dnnl_pattern(conv_name, with_bias, elt))
+            dnnl_patterns.append(make_dnnl_pattern("nn.dense", with_bias, elt))
     return dnnl_patterns
diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc
index b1b2f580cf94..7971b9cf67d2 100644
--- a/src/relay/backend/contrib/dnnl/codegen.cc
+++ b/src/relay/backend/contrib/dnnl/codegen.cc
@@ -31,6 +31,7 @@
 #include
 #include
+#include
 #include

 #include "../../utils.h"
@@ -439,6 +440,23 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
   using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
   using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;

+  std::map<std::string, std::string> op_map{
+      {"bias", "add"},
+      {"relu", "nn.relu"},
+      {"tanh", "tanh"},
+      {"sigmoid", "sigmoid"},
+  };
+
+  std::vector<std::string> ParsingOpList(std::string op, std::string pattern_name) {
+    std::vector<std::string> op_list = {"nn." + op};
+    for (auto& t : op_map) {
+      if (pattern_name.find(t.first) != std::string::npos) {
+        op_list.push_back(t.second);
+      }
+    }
+    return op_list;
+  }
+
  public:
   DNNLJSONSerializer(const std::string& symbol, const Expr& expr) : JSONSerializer(symbol, expr) {}
@@ -453,28 +471,29 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
     ICHECK(comp.defined()) << "DNNL JSON runtime only supports composite functions.";
     name = comp.value();

-    if (name == "dnnl.conv2d_bias_relu") {
-      call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "nn.relu"});
-    } else if (name == "dnnl.conv2d_bias_tanh") {
-      call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "tanh"});
-      ICHECK(call->op.as<OpNode>()) << "Not op node";
-    } else if (name == "dnnl.conv2d_bias_sigmoid") {
-      call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "sigmoid"});
+    if (name.find("dnnl.conv2d_transpose") != std::string::npos) {
+      std::vector<std::string> op_list = ParsingOpList("conv2d_transpose", name);
+      call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
       ICHECK(call->op.as<OpNode>()) << "Not op node";
-    } else if (name == "dnnl.conv2d_bias") {
-      call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "add"});
+    } else if (name.find("dnnl.conv3d_transpose") != std::string::npos) {
+      std::vector<std::string> op_list = ParsingOpList("conv3d_transpose", name);
+      call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
       ICHECK(call->op.as<OpNode>()) << "Not op node";
-    } else if (name == "dnnl.conv2d_relu") {
-      call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "nn.relu"});
+    } else if (name.find("dnnl.conv1d") != std::string::npos) {
+      std::vector<std::string> op_list = ParsingOpList("conv1d", name);
+      call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
       ICHECK(call->op.as<OpNode>()) << "Not op node";
-    } else if (name == "dnnl.conv2d_tanh") {
-      call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "tanh"});
+    } else if (name.find("dnnl.conv2d") != std::string::npos) {
+      std::vector<std::string> op_list = ParsingOpList("conv2d", name);
+      call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
       ICHECK(call->op.as<OpNode>()) << "Not op node";
-    } else if (name == "dnnl.conv2d_sigmoid") {
-      call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "sigmoid"});
+    } else if (name.find("dnnl.conv3d") != std::string::npos) {
+      std::vector<std::string> op_list = ParsingOpList("conv3d", name);
+      call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
       ICHECK(call->op.as<OpNode>()) << "Not op node";
-    } else if (name == "dnnl.dense_bias") {
-      call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.dense", "add"});
+    } else if (name.find("dnnl.dense") != std::string::npos) {
+      std::vector<std::string> op_list = ParsingOpList("dense", name);
+      call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
       ICHECK(call->op.as<OpNode>()) << "Not op node";
     } else {
       LOG(FATAL) << "Unrecognized DNNL pattern: " << name;
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index f9f1961e2697..6d5e5543cd40 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -26,6 +26,7 @@
 #include
 #include
+#include
 #include
 #include
@@ -90,44 +91,95 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
  private:
  // Build up the engine based on the input graph.
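
The serializer above no longer hard-codes one branch per fused pattern; it expands the composite name into an ordered op list and lets `GetRootCall` walk it with `op_list.size() - 1` as the depth. A plain-Python sketch of that expansion follows; `OP_MAP` and `expand_pattern_name` are illustrative stand-ins for the C++ `op_map` / `ParsingOpList`, not part of the patch.

```python
# Sketch: expand a fused composite name into the ordered op list GetRootCall walks.
OP_MAP = {"bias": "add", "relu": "nn.relu", "tanh": "tanh", "sigmoid": "sigmoid"}


def expand_pattern_name(base_op, pattern_name):
    """Return e.g. ["nn.conv2d", "add", "nn.relu"] for "dnnl.conv2d_bias_relu"."""
    op_list = ["nn." + base_op]
    for token, relay_op in OP_MAP.items():
        if token in pattern_name:
            op_list.append(relay_op)
    return op_list


assert expand_pattern_name("conv2d", "dnnl.conv2d_bias_relu") == ["nn.conv2d", "add", "nn.relu"]
assert expand_pattern_name("conv3d_transpose", "dnnl.conv3d_transpose_bias") == [
    "nn.conv3d_transpose",
    "add",
]
```
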
+  std::map<std::string, tag> layout_dict{
+      {"NCW", tag::ncw},       {"OIW", tag::oiw},     {"GOIW", tag::goiw},   {"NCHW", tag::nchw},
+      {"OIHW", tag::oihw},     {"GOIHW", tag::goihw}, {"NCDHW", tag::ncdhw}, {"OIDHW", tag::oidhw},
+      {"GOIDHW", tag::goidhw}, {"IOHW", tag::iohw},   {"GIOHW", tag::giohw}, {"IODHW", tag::iodhw},
+      {"GIODHW", tag::giodhw},
+  };
+
+  std::map<std::string, dnnl::algorithm> elt_name2algo{
+      {"abs", dnnl::algorithm::eltwise_abs},
+      {"exp", dnnl::algorithm::eltwise_exp},
+      {"log", dnnl::algorithm::eltwise_log},
+      {"sqrt", dnnl::algorithm::eltwise_sqrt},
+      {"round", dnnl::algorithm::eltwise_round},
+      {"logsumexp", dnnl::algorithm::eltwise_logsigmoid},
+      {"nn.relu", dnnl::algorithm::eltwise_relu},
+      {"nn.leaky_relu", dnnl::algorithm::eltwise_relu},
+      {"tanh", dnnl::algorithm::eltwise_tanh},
+      {"sigmoid", dnnl::algorithm::eltwise_logistic},
+      {"clip", dnnl::algorithm::eltwise_clip},
+  };
+
+  bool ParsingOpName(const std::string op_name, dnnl::primitive_attr attr) {
+    // Define RegExp.
+    std::regex bias_add_pat(".*_bias.*");
+    std::regex relu_pat(".*_relu.*");
+    std::regex tanh_pat(".*_tanh.*");
+    std::regex sigmoid_pat(".*_sigmoid.*");
+
+    // Parsing post-ops.
+    dnnl::post_ops ops;
+    if (std::regex_match(op_name, relu_pat)) {
+      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
+    }
+    if (std::regex_match(op_name, tanh_pat)) {
+      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_tanh, 0.f, 0.f);
+    }
+    if (std::regex_match(op_name, sigmoid_pat)) {
+      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
+    }
+    attr.set_post_ops(ops);
+
+    // Parsing bias_add.
+    return std::regex_match(op_name, bias_add_pat) ? true : false;
+  }
+
+  dnnl::memory::dims TransformStr2Dims(std::vector<std::string> strs, std::string str_name) {
+    dnnl::memory::dims out_dims;
+    if (str_name == "dilates") {
+      std::transform(strs.begin(), strs.end(), std::back_inserter(out_dims),
+                     [](const std::string& str) { return std::stoi(str) - 1; });
+    } else {
+      std::transform(strs.begin(), strs.end(), std::back_inserter(out_dims),
+                     [](const std::string& str) { return std::stoi(str); });
+    }
+    return out_dims;
+  }
+
   void BuildEngine() {
     engine_ = dnnl::engine(dnnl::engine::kind::cpu, 0);
     stream_ = dnnl::stream(engine_);
+    std::regex conv_pat(".*conv[1-3]d.*");
+    std::regex conv_tranpose_pat(".*conv[1-3]d_transpose.*");
+    std::regex dense_pat(".*dense.*");
+    std::regex max_pool_pat(".*max_pool[1-3]d");
+    std::regex avg_pool_pat(".*avg_pool[1-3]d");
+
     // Build subgraph engine.
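
One convention in `TransformStr2Dims` is easy to miss: Relay stores dilation as the sampling step (1 means dense), while the DNNL API expects the number of inserted gaps between taps (0 means dense), hence the minus one for "dilates". A minimal Python sketch of the conversion; the helper name `str_attrs_to_dims` is made up for illustration.

```python
def str_attrs_to_dims(strs, attr_name):
    """Convert JSON string attributes such as ["2", "2"] into integer dims.

    Dilation is shifted by one because Relay's 1 (no dilation) corresponds
    to DNNL's 0 (no inserted gaps).
    """
    if attr_name == "dilates":
        return [int(s) - 1 for s in strs]
    return [int(s) for s in strs]


assert str_attrs_to_dims(["2", "2"], "strides") == [2, 2]
assert str_attrs_to_dims(["1", "1"], "dilates") == [0, 0]  # dense convolution in DNNL terms
```
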
for (size_t nid = 0; nid < nodes_.size(); ++nid) { const auto& node = nodes_[nid]; if (node.GetOpType() == "kernel") { ICHECK_EQ(node.GetOpType(), "kernel"); auto op_name = node.GetOpName(); - if ("nn.conv2d" == op_name) { - Conv2d(nid); - } else if ("dnnl.conv2d_relu" == op_name) { - Conv2d(nid, true, false, dnnl::algorithm::eltwise_relu); - } else if ("dnnl.conv2d_tanh" == op_name) { - Conv2d(nid, true, false, dnnl::algorithm::eltwise_tanh); - } else if ("dnnl.conv2d_sigmoid" == op_name) { - Conv2d(nid, true, false, dnnl::algorithm::eltwise_logistic); - } else if ("dnnl.conv2d_bias" == op_name) { - Conv2d(nid, false, true); - } else if ("dnnl.conv2d_bias_relu" == op_name) { - Conv2d(nid, true, true, dnnl::algorithm::eltwise_relu); - } else if ("dnnl.conv2d_bias_tanh" == op_name) { - Conv2d(nid, true, true, dnnl::algorithm::eltwise_tanh); - } else if ("dnnl.conv2d_bias_sigmoid" == op_name) { - Conv2d(nid, true, true, dnnl::algorithm::eltwise_logistic); - } else if ("nn.dense" == op_name) { + if (std::regex_match(op_name, conv_tranpose_pat)) { + Deconvolution(nid); + } else if (std::regex_match(op_name, conv_pat)) { + Convolution(nid); + } else if (std::regex_match(op_name, dense_pat)) { Dense(nid); - } else if ("dnnl.dense_bias" == op_name) { - Dense(nid, true); } else if ("nn.batch_norm" == op_name) { BatchNorm(nid); - } else if ("nn.relu" == op_name) { - Eltwise(nid, dnnl::algorithm::eltwise_relu); - } else if ("tanh" == op_name) { - Eltwise(nid, dnnl::algorithm::eltwise_tanh); - } else if ("sigmoid" == op_name) { - Eltwise(nid, dnnl::algorithm::eltwise_logistic); + } else if (std::regex_match(op_name, max_pool_pat)) { + Pooling(nid, dnnl::algorithm::pooling_max); + } else if (std::regex_match(op_name, avg_pool_pat)) { + Pooling(nid, dnnl::algorithm::pooling_avg); + } else if (elt_name2algo.count(op_name)) { + Eltwise(nid); + } else if ("nn.softmax" == op_name) { + Softmax(nid); } else if ("add" == op_name) { Binary(nid, dnnl::algorithm::binary_add); } else if ("multiply" == op_name) { @@ -166,73 +218,73 @@ class DNNLJSONRuntime : public JSONRuntimeBase { return entry_out_mem_[eid].first; } - void Conv2d(const size_t& nid, const bool has_elt = false, const bool has_bias = false, - dnnl::algorithm algo = dnnl::algorithm::eltwise_relu) { + void Convolution(const size_t& nid) { auto node = nodes_[nid]; + auto op_name = node.GetOpName(); + dnnl::primitive_attr attr; + bool has_bias = ParsingOpName(op_name, attr); // Setup attributes. auto data_entry = node.GetInputs()[0]; auto weight_entry = node.GetInputs()[1]; + JSONGraphNodeEntry out_entry(nid, 0); dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_]; + dnnl::memory::dims out_shape = nodes_[out_entry.id_].GetOpShape()[out_entry.index_]; + dnnl::memory::dim channels = + node.GetAttr>("channels")[0] != "" + ? 
std::stoi(node.GetAttr>("channels")[0]) + : out_shape[1]; std::vector str_strides = node.GetAttr>("strides"); std::vector str_dilates = node.GetAttr>("dilation"); std::vector str_padding = node.GetAttr>("padding"); + std::vector str_padding_l(str_padding.begin(), + str_padding.begin() + str_padding.size() / 2); + std::vector str_padding_r(str_padding.end() - str_padding.size() / 2, + str_padding.end()); dnnl::memory::dim groups = std::stoi(node.GetAttr>("groups")[0]); + std::string data_layout = node.GetAttr>("data_layout")[0]; + std::string kernel_layout = node.GetAttr>("kernel_layout")[0]; - dnnl::memory::dim N = input_shape[0], // batch size - IC = input_shape[1], // input channels - IH = input_shape[2], // input height - IW = input_shape[3], // input width - OC = weight_shape[0], // output channels - KH = weight_shape[2], // weight height - KW = weight_shape[3], // weight width - PW_L = std::stoi(str_padding[1]), // width padding: left - PW_R = std::stoi(str_padding[3]), // width padding: right - PH_L = std::stoi(str_padding[0]), // height padding: top - PH_R = std::stoi(str_padding[2]), // height padding: bottom - SH = std::stoi(str_strides[0]), // height-wise stride - SW = std::stoi(str_strides[1]), // weight-wise stride - DH = std::stoi(str_dilates[0]) - 1, // height-wise dilate - DW = std::stoi(str_dilates[1]) - 1, // weight-wise dilate - DKH = 1 + (KH - 1) * (DH + 1), // dilated weight height - DKW = 1 + (KW - 1) * (DW + 1), // dilated weight width - OH = (IH - DKH + PH_L + PH_R) / SH + 1, // output height - OW = (IW - DKW + PW_L + PW_R) / SW + 1; // output width + // Check layout. + if (layout_dict.find(data_layout) == layout_dict.end() || + layout_dict.find(kernel_layout) == layout_dict.end()) { + LOG(FATAL) << "Unsupported layout for conv: " << data_layout << " " << kernel_layout; + } // Memory shapes. - dnnl::memory::dims src_dims = {N, IC, IH, IW}; - dnnl::memory::dims weights_dims = {OC, IC, KH, KW}; + dnnl::memory::dims src_dims = input_shape; // {N, IC, ID, IH, IW} + dnnl::memory::dims weights_dims = weight_shape; // {OC, IC, KD, KH, KW} if (groups > 1) { - weights_dims = {groups, 1, IC / groups, KH, KW}; + weights_dims = {groups, channels / groups, input_shape[1] / groups}; + weights_dims.insert(weights_dims.end(), weight_shape.begin() + 2, weight_shape.end()); + kernel_layout.insert(0, "G"); } - dnnl::memory::dims bias_dims = {OC}; - dnnl::memory::dims dst_dims = {N, OC, OH, OW}; - dnnl::memory::dims strides_dims = {SH, SW}; - dnnl::memory::dims dilates_dims = {DH, DW}; - dnnl::memory::dims padding_dims_l = {PH_L, PW_L}; - dnnl::memory::dims padding_dims_r = {PH_R, PW_R}; + dnnl::memory::dims bias_dims = {channels}; + dnnl::memory::dims dst_dims = out_shape; // {N, OC, OD, OH, OW} + dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides, "strides"); + dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, "dilates"); + dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l, "padding"); + dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r, "padding"); // Memory descriptions. 
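
Two pieces of bookkeeping in the convolution setup above are worth spelling out: Relay flattens padding as (begin..., end...) so it is split in half for DNNL's separate left/right padding arguments, and grouped convolutions reshape the weight dims to a G-prefixed layout. A small sketch under those assumptions (function names are illustrative, not from the patch):

```python
def split_padding(padding):
    """Split Relay's flattened [begin..., end...] padding into DNNL's l/r halves."""
    half = len(padding) // 2
    return list(padding[:half]), list(padding[half:])


def grouped_weight_dims(weight_shape, in_channels, out_channels, groups):
    """Reshape OIHW-style weight dims to GOIHW-style dims when groups > 1."""
    if groups == 1:
        return list(weight_shape)
    return [groups, out_channels // groups, in_channels // groups] + list(weight_shape[2:])


assert split_padding([1, 1, 2, 2]) == ([1, 1], [2, 2])
# Depthwise conv2d: weights (32, 1, 3, 3) with groups=32 become (32, 1, 1, 3, 3).
assert grouped_weight_dims((32, 1, 3, 3), in_channels=32, out_channels=32, groups=32) == [32, 1, 1, 3, 3]
```
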
- auto conv_src_md = dnnl::memory::desc(src_dims, dt::f32, tag::any); - auto conv_weights_md = dnnl::memory::desc(weights_dims, dt::f32, tag::any); + auto conv_src_md = dnnl::memory::desc(src_dims, dt::f32, layout_dict[data_layout]); + auto conv_weights_md = dnnl::memory::desc(weights_dims, dt::f32, layout_dict[kernel_layout]); auto conv_bias_md = dnnl::memory::desc(bias_dims, dt::f32, tag::any); - auto conv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::nchw); + auto conv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, layout_dict[data_layout]); // Covn2d description. - auto conv_desc = dnnl::convolution_forward::desc( - dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, conv_src_md, - conv_weights_md, conv_bias_md, conv_dst_md, strides_dims, dilates_dims, padding_dims_l, - padding_dims_r); - - // Enable elementwise post-ops - dnnl::primitive_attr attr; - if (has_elt) { - dnnl::post_ops ops; - ops.append_eltwise(1.f, algo, 0.f, 0.f); - attr.set_post_ops(ops); - } - + auto conv_desc = + has_bias ? dnnl::convolution_forward::desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, + conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, strides_dims, + dilates_dims, padding_dims_l, padding_dims_r) + : dnnl::convolution_forward::desc(dnnl::prop_kind::forward_inference, + dnnl::algorithm::convolution_direct, conv_src_md, + conv_weights_md, conv_dst_md, strides_dims, + dilates_dims, padding_dims_l, padding_dims_r); + + // Enable elementwise post-ops. auto conv2d_prim_desc = dnnl::convolution_forward::primitive_desc(conv_desc, attr, engine_); // Push to the network. @@ -240,53 +292,163 @@ class DNNLJSONRuntime : public JSONRuntimeBase { net_.push_back(conv); // Data memory. - ICHECK_EQ(node.GetAttr>("data_layout")[0], "NCHW"); - auto conv2d_src_memory = BindDNNLMemory(data_entry, {src_dims, dt::f32, tag::nchw}); + auto conv2d_src_memory = BindDNNLMemory(data_entry, conv_src_md); // Weight memory. - ICHECK_EQ(node.GetAttr>("kernel_layout")[0], "OIHW"); - auto conv2d_weights_memory = BindDNNLMemory( - weight_entry, {weights_dims, dt::f32, (groups > 1) ? tag::goihw : tag::oihw}); + auto conv2d_weights_memory = BindDNNLMemory(weight_entry, conv_weights_md); + + // Output memory. + auto conv2d_dst_memory = BindDNNLMemory(out_entry, conv2d_prim_desc.dst_desc()); // Bias memory. auto conv2d_bias_memory = dnnl::memory({bias_dims, dt::f32, tag::x}, engine_); if (has_bias) { auto bias_entry = node.GetInputs()[2]; BindDNNLMemory(bias_entry, conv2d_bias_memory); + + // Bind memory buffers. + net_args_.push_back({{DNNL_ARG_SRC, conv2d_src_memory}, + {DNNL_ARG_WEIGHTS, conv2d_weights_memory}, + {DNNL_ARG_BIAS, conv2d_bias_memory}, + {DNNL_ARG_DST, conv2d_dst_memory}}); } else { - float bias[OC] = {0}; - write_to_dnnl_memory(bias, conv2d_bias_memory, OC * sizeof(float)); + // Bind memory buffers. + net_args_.push_back({{DNNL_ARG_SRC, conv2d_src_memory}, + {DNNL_ARG_WEIGHTS, conv2d_weights_memory}, + {DNNL_ARG_DST, conv2d_dst_memory}}); } + } - // Output memory. + void Deconvolution(const size_t& nid) { + auto node = nodes_[nid]; + auto op_name = node.GetOpName(); + dnnl::primitive_attr attr; + bool has_bias = ParsingOpName(op_name, attr); + + // Setup attributes. 
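
At the Relay level, the point of the descriptor-plus-post-ops path above is that a conv2d → bias_add → relu chain arrives as one composite (e.g. `dnnl.conv2d_bias_relu`) and runs as a single DNNL primitive with a fused eltwise post-op. A minimal usage sketch, assuming a TVM build with the DNNL codegen enabled; shapes and parameter values are arbitrary.

```python
import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib import dnnl

# conv2d + bias_add + relu; after partitioning this should be wrapped in a
# single DNNL composite handled by the Convolution() builder above.
data = relay.var("data", shape=(1, 32, 8, 8), dtype="float32")
weight = relay.var("weight", shape=(16, 32, 3, 3), dtype="float32")
bias = relay.var("bias", shape=(16,), dtype="float32")
conv = relay.nn.conv2d(data, weight, channels=16, kernel_size=(3, 3), padding=(1, 1))
out = relay.nn.relu(relay.nn.bias_add(conv, bias))
mod = tvm.IRModule.from_expr(out)

params = {
    "weight": np.random.uniform(-1, 1, (16, 32, 3, 3)).astype("float32"),
    "bias": np.random.uniform(-1, 1, (16,)).astype("float32"),
}
mod = dnnl.partition_for_dnnl(mod, params)
print(mod)  # expect a composite function named "dnnl.conv2d_bias_relu"
```
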
+ auto data_entry = node.GetInputs()[0]; + auto weight_entry = node.GetInputs()[1]; JSONGraphNodeEntry out_entry(nid, 0); - auto conv2d_dst_memory = BindDNNLMemory(out_entry, conv2d_prim_desc.dst_desc()); + dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; + dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_]; + dnnl::memory::dims out_shape = nodes_[out_entry.id_].GetOpShape()[out_entry.index_]; + dnnl::memory::dim channels = + node.GetAttr>("channels")[0] != "" + ? std::stoi(node.GetAttr>("channels")[0]) + : out_shape[1]; + std::vector str_strides = node.GetAttr>("strides"); + std::vector str_dilates = node.GetAttr>("dilation"); + std::vector str_padding = node.GetAttr>("padding"); + std::vector str_padding_l(str_padding.begin(), + str_padding.begin() + str_padding.size() / 2); + std::vector str_padding_r(str_padding.end() - str_padding.size() / 2, + str_padding.end()); + dnnl::memory::dim groups = std::stoi(node.GetAttr>("groups")[0]); + std::string data_layout = node.GetAttr>("data_layout")[0]; + std::string kernel_layout = node.GetAttr>("kernel_layout")[0]; - // Bind memory buffers. - net_args_.push_back({{DNNL_ARG_SRC, conv2d_src_memory}, - {DNNL_ARG_WEIGHTS, conv2d_weights_memory}, - {DNNL_ARG_BIAS, conv2d_bias_memory}, - {DNNL_ARG_DST, conv2d_dst_memory}}); + // Check layout. + if (layout_dict.find(data_layout) == layout_dict.end() || + layout_dict.find(kernel_layout) == layout_dict.end()) { + LOG(FATAL) << "Unsupported layout: " << data_layout << " " << kernel_layout; + } + + // Memory shapes. + dnnl::memory::dims src_dims = input_shape; // {N, IC, ID, IH, IW} + dnnl::memory::dims weights_dims = weight_shape; // {OC, IC, KD, KH, KW} + + // Check weight shape, transform to `OIHW` + if (weights_dims[0] == src_dims[1] && weights_dims[1] == channels) { + std::swap(weights_dims[0], weights_dims[1]); + } + if (kernel_layout == "OIDHW") { + kernel_layout = "IODHW"; + } + if (groups > 1) { + weights_dims = {groups, channels / groups, input_shape[1] / groups}; + weights_dims.insert(weights_dims.end(), weight_shape.begin() + 2, weight_shape.end()); + kernel_layout.insert(0, "G"); + } + dnnl::memory::dims bias_dims = {channels}; + dnnl::memory::dims dst_dims = out_shape; // {N, OC, OD, OH, OW} + dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides, "strides"); + dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, "dilates"); + dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l, "padding"); + dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r, "padding"); + + // Memory descriptions. + auto deconv_src_md = dnnl::memory::desc(src_dims, dt::f32, layout_dict[data_layout]); + auto deconv_weights_md = dnnl::memory::desc(weights_dims, dt::f32, layout_dict[kernel_layout]); + auto deconv_bias_md = dnnl::memory::desc(bias_dims, dt::f32, tag::any); + auto deconv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, layout_dict[data_layout]); + + // Transposed covn2d description. + auto deconv_desc = + has_bias ? 
dnnl::deconvolution_forward::desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, + deconv_src_md, deconv_weights_md, deconv_bias_md, deconv_dst_md, + strides_dims, dilates_dims, padding_dims_l, padding_dims_r) + : dnnl::deconvolution_forward::desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, + deconv_src_md, deconv_weights_md, deconv_dst_md, strides_dims, dilates_dims, + padding_dims_l, padding_dims_r); + + // Enable elementwise post-ops. + auto deconv2d_prim_desc = + dnnl::deconvolution_forward::primitive_desc(deconv_desc, attr, engine_); + + // Push to the network. + auto deconv = dnnl::deconvolution_forward(deconv2d_prim_desc); + net_.push_back(deconv); + + // Data memory. + auto deconv2d_src_memory = BindDNNLMemory(data_entry, deconv_src_md); + + // Weight memory. + auto deconv2d_weights_memory = BindDNNLMemory(weight_entry, deconv_weights_md); + + // Output memory. + auto deconv2d_dst_memory = BindDNNLMemory(out_entry, deconv2d_prim_desc.dst_desc()); + + // Bias memory. + auto deconv2d_bias_memory = dnnl::memory({bias_dims, dt::f32, tag::x}, engine_); + if (has_bias) { + auto bias_entry = node.GetInputs()[2]; + BindDNNLMemory(bias_entry, deconv2d_bias_memory); + + // Bind memory buffers. + net_args_.push_back({{DNNL_ARG_SRC, deconv2d_src_memory}, + {DNNL_ARG_WEIGHTS, deconv2d_weights_memory}, + {DNNL_ARG_BIAS, deconv2d_bias_memory}, + {DNNL_ARG_DST, deconv2d_dst_memory}}); + } else { + // Bind memory buffers. + net_args_.push_back({{DNNL_ARG_SRC, deconv2d_src_memory}, + {DNNL_ARG_WEIGHTS, deconv2d_weights_memory}, + {DNNL_ARG_DST, deconv2d_dst_memory}}); + } } - void Dense(const size_t& nid, const bool has_bias = false) { + void Dense(const size_t& nid) { auto node = nodes_[nid]; + auto op_name = node.GetOpName(); + dnnl::primitive_attr attr; + bool has_bias = ParsingOpName(op_name, attr); // Setup attributes. auto data_entry = node.GetInputs()[0]; auto weight_entry = node.GetInputs()[1]; + JSONGraphNodeEntry out_entry(nid, 0); dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_]; - - dnnl::memory::dim B = input_shape[0], // batch size - IC = input_shape[1], // input channels - OC = weight_shape[0]; // output channels + dnnl::memory::dims out_shape = nodes_[out_entry.id_].GetOpShape()[out_entry.index_]; + dnnl::memory::dim OC = out_shape[1]; // Memory shapes. - dnnl::memory::dims data_dims = {B, IC}; - dnnl::memory::dims weight_dims = {OC, IC}; + dnnl::memory::dims data_dims = input_shape; + dnnl::memory::dims weight_dims = weight_shape; dnnl::memory::dims bias_dims = {OC}; - dnnl::memory::dims out_dims = {B, OC}; + dnnl::memory::dims out_dims = out_shape; // Memory descriptions. auto data_md = dnnl::memory::desc({data_dims, dt::f32, tag::nc}); @@ -297,7 +459,9 @@ class DNNLJSONRuntime : public JSONRuntimeBase { // Dense description. auto dense_desc = dnnl::inner_product_forward::desc(dnnl::prop_kind::forward_inference, data_md, weight_md, bias_md, dst_md); - auto dense_prim_desc = dnnl::inner_product_forward::primitive_desc(dense_desc, engine_); + + // Enable elementwise post-ops. + auto dense_prim_desc = dnnl::inner_product_forward::primitive_desc(dense_desc, attr, engine_); auto dense = dnnl::inner_product_forward(dense_prim_desc); net_.push_back(dense); @@ -317,7 +481,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase { } // Output memory. 
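
One subtlety in the deconvolution path above: Relay's transposed-convolution kernels arrive with input channels leading, i.e. (IC, OC, kH, kW), while the DNNL-side shape bookkeeping wants output channels first, so the first two weight axes are swapped before the optional group reshape. A sketch of that normalization with an illustrative helper name:

```python
def normalize_deconv_weight_dims(weight_shape, in_channels, out_channels):
    """Swap the leading (IC, OC) axes of a transposed-conv kernel to (OC, IC)."""
    dims = list(weight_shape)
    if dims[0] == in_channels and dims[1] == out_channels:
        dims[0], dims[1] = dims[1], dims[0]
    return dims


# A conv2d_transpose kernel of shape (IC=32, OC=16, 3, 3) becomes (16, 32, 3, 3).
assert normalize_deconv_weight_dims((32, 16, 3, 3), in_channels=32, out_channels=16) == [
    16,
    32,
    3,
    3,
]
```
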
- JSONGraphNodeEntry out_entry(nid, 0); auto dst_memory = BindDNNLMemory(out_entry, dense_prim_desc.dst_desc()); net_args_.push_back({{DNNL_ARG_SRC, data_memory}, @@ -368,15 +531,85 @@ class DNNLJSONRuntime : public JSONRuntimeBase { {DNNL_ARG_VARIANCE, variance_memory}}); } - void Eltwise(const size_t& nid, dnnl::algorithm algo) { + void Pooling(const size_t& nid, dnnl::algorithm algo) { + auto node = nodes_[nid]; + + // Setup attributes. + auto data_entry = node.GetInputs()[0]; + JSONGraphNodeEntry out_entry(nid, 0); + dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; + dnnl::memory::dims out_shape = nodes_[out_entry.id_].GetOpShape()[out_entry.index_]; + std::vector str_kernel = node.GetAttr>("pool_size"); + std::vector str_strides = node.GetAttr>("strides"); + std::vector str_padding = node.GetAttr>("padding"); + std::vector str_padding_l(str_padding.begin(), + str_padding.begin() + str_padding.size() / 2); + std::vector str_padding_r(str_padding.end() - str_padding.size() / 2, + str_padding.end()); + std::vector str_dilates = node.GetAttr>("dilation"); + std::string layout = node.GetAttr>("layout")[0]; + + // Check layout. + if (layout_dict.find(layout) == layout_dict.end()) { + LOG(FATAL) << "Unsupported layout for pooling: " << layout; + } + + // Attributes related to AvgPool + if (algo == dnnl::algorithm::pooling_avg) { + int int_countpad = std::stoi(node.GetAttr>("count_include_pad")[0]); + bool count_include_pad = int_countpad != 0 ? true : false; + algo = count_include_pad ? dnnl::algorithm::pooling_avg_include_padding + : dnnl::algorithm::pooling_avg_exclude_padding; + } + + dnnl::memory::dims src_dims = input_shape; + dnnl::memory::dims dst_dims = out_shape; + dnnl::memory::dims kernel_dims = TransformStr2Dims(str_kernel, "kernel"); + dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides, "strides"); + dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, "dilates"); + dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l, "padding"); + dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r, "padding"); + + // Memory descriptions. + auto pool_src_md = dnnl::memory::desc(src_dims, dt::f32, layout_dict[layout]); + auto pool_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::any); + + // Pooling description. + auto pool_desc = dnnl::pooling_forward::desc(dnnl::prop_kind::forward_inference, algo, + pool_src_md, pool_dst_md, strides_dims, + kernel_dims, padding_dims_l, padding_dims_r); + + auto pool_prim_desc = dnnl::pooling_forward::primitive_desc(pool_desc, engine_, true); + auto pool = dnnl::pooling_forward(pool_prim_desc); + net_.push_back(pool); + + // Memories. + auto pool2d_src_memory = BindDNNLMemory(data_entry, pool_src_md); + + auto pool2d_dst_memory = BindDNNLMemory(out_entry, pool_prim_desc.dst_desc()); + + // Bind memory buffers. 
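
For average pooling, the runtime above folds Relay's `count_include_pad` attribute into the choice of DNNL algorithm rather than passing it as a separate flag. Sketched in Python, with strings standing in for the `dnnl::algorithm` enum values:

```python
def select_pool_algorithm(op_kind, count_include_pad=False):
    """Map a Relay pooling op to the name of a DNNL pooling algorithm."""
    if op_kind == "max":
        return "pooling_max"
    # Average pooling: padded cells are either counted in the divisor or excluded.
    return "pooling_avg_include_padding" if count_include_pad else "pooling_avg_exclude_padding"


assert select_pool_algorithm("max") == "pooling_max"
assert select_pool_algorithm("avg", count_include_pad=True) == "pooling_avg_include_padding"
assert select_pool_algorithm("avg", count_include_pad=False) == "pooling_avg_exclude_padding"
```
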
+ net_args_.push_back({{DNNL_ARG_SRC, pool2d_src_memory}, {DNNL_ARG_DST, pool2d_dst_memory}}); + } + + void Eltwise(const size_t& nid) { auto node = nodes_[nid]; + auto op_name = node.GetOpName(); + auto algo = elt_name2algo[op_name]; auto data_entry = node.GetInputs()[0]; dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dt::f32); + float alpha = 0., beta = 0.; + if (op_name == "clip") { + alpha = std::stof(node.GetAttr>("a_min")[0]); + beta = std::stof(node.GetAttr>("a_max")[0]); + } else if (op_name == "nn.leaky_relu") { + alpha = std::stof(node.GetAttr>("alpha")[0]); + } auto elt_desc = - dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, data_md, 0); + dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, data_md, alpha, beta); auto elt_prim_desc = dnnl::eltwise_forward::primitive_desc(elt_desc, engine_); ICHECK(data_md == elt_prim_desc.dst_desc()); @@ -390,6 +623,32 @@ class DNNLJSONRuntime : public JSONRuntimeBase { net_args_.push_back({{DNNL_ARG_SRC, data_memory}, {DNNL_ARG_DST, out_memory}}); } + void Softmax(const size_t& nid) { + auto node = nodes_[nid]; + + auto data_entry = node.GetInputs()[0]; + dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; + int axis = std::stoi(node.GetAttr>("axis")[0]); + if (axis < 0) { + axis = shape.size() + axis; + } + dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dt::f32); + + auto softmax_desc = + dnnl::softmax_forward::desc(dnnl::prop_kind::forward_inference, data_md, axis); + auto softmax_prim_desc = dnnl::softmax_forward::primitive_desc(softmax_desc, engine_); + ICHECK(data_md == softmax_prim_desc.dst_desc()); + + auto softmax = dnnl::softmax_forward(softmax_prim_desc); + net_.push_back(softmax); + + auto data_memory = BindDNNLMemory(data_entry, data_md); + JSONGraphNodeEntry out_entry(nid, 0); + auto out_memory = BindDNNLMemory(out_entry, data_md); + + net_args_.push_back({{DNNL_ARG_SRC, data_memory}, {DNNL_ARG_DST, out_memory}}); + } + void Binary(const size_t& nid, dnnl::algorithm algo) { auto node = nodes_[nid]; diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py index 7adf3e40ad33..4d1972d6a3b0 100755 --- a/tests/python/contrib/test_dnnl.py +++ b/tests/python/contrib/test_dnnl.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
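
The Eltwise and Softmax builders above carry two small conventions: `clip` and `nn.leaky_relu` reuse the eltwise alpha/beta slots (a_min/a_max and the leak slope, respectively), and a negative softmax axis is normalized against the tensor rank. A compact Python sketch; the helper names are illustrative only.

```python
def eltwise_alpha_beta(op_name, attrs):
    """Return the (alpha, beta) pair fed to the DNNL eltwise descriptor."""
    if op_name == "clip":
        return float(attrs["a_min"]), float(attrs["a_max"])
    if op_name == "nn.leaky_relu":
        return float(attrs["alpha"]), 0.0
    return 0.0, 0.0


def normalize_softmax_axis(axis, ndim):
    """Convert a possibly negative Relay axis into the non-negative DNNL axis."""
    return axis + ndim if axis < 0 else axis


assert eltwise_alpha_beta("clip", {"a_min": "-0.2", "a_max": "0.4"}) == (-0.2, 0.4)
assert eltwise_alpha_beta("nn.leaky_relu", {"alpha": "0.1"}) == (0.1, 0.0)
assert normalize_softmax_axis(-1, ndim=2) == 1
```
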
-import numpy as np import pytest import itertools import tvm @@ -22,6 +21,7 @@ from tvm import relay from tvm.relay.op.contrib import dnnl import tvm.testing +import numpy as np has_dnnl_codegen = pytest.mark.skipif( not tvm.get_global_func("relay.ext.dnnl", True), reason="DNNL codegen not available" @@ -65,6 +65,7 @@ def check_dnnl_used(mod): result_key = mode + ("_dnnl" if use_dnnl else "") if use_dnnl: mod = dnnl.partition_for_dnnl(mod, params) + check_dnnl_used(mod) with tvm.transform.PassContext(opt_level=3): func = relay.create_executor(mode, mod=mod, device=dev, target=target).evaluate() if run_module: @@ -99,6 +100,79 @@ def run_and_verify_func(config, run_module, target="llvm", dtype="float32"): run_and_verify(f, input_dict, params, target, run_module) +def get_conv1d( + x_shape=((1, 3, 224)), + k_shape=(10, 3, 3), + groups=1, + padding=(1, 1), + strides=(1), + dilation=(1), + channels=None, + activation=None, + dtype="float32", +): + x = relay.var("x", shape=(x_shape), dtype=dtype) + kernel = relay.var("kernel", shape=(k_shape), dtype=dtype) + out = relay.nn.conv1d( + x, + kernel, + kernel_size=k_shape[2:3], + groups=groups, + padding=padding, + strides=strides, + dilation=dilation, + channels=k_shape[0], + ) + dic = {"x": x_shape, "kernel": k_shape} + param_lst = ["kernel"] + + if activation == "relu": + return relay.nn.relu(out), dic, param_lst + elif activation == "tanh": + return relay.tanh(out), dic, param_lst + elif activation == "sigmoid": + return relay.sigmoid(out), dic, param_lst + else: + return out, dic, param_lst + + +def get_conv1d_bias(x_shape=(1, 3, 224), k_shape=(10, 3, 3), activation=None, dtype="float32"): + conv, dic, param_lst = get_conv1d(x_shape=x_shape, k_shape=k_shape, dtype=dtype) + bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype) + out = relay.nn.bias_add(conv, bias) + dic["bias"] = (k_shape[0],) + param_lst += ["bias"] + + if activation == "relu": + return relay.nn.relu(out), dic, param_lst + elif activation == "tanh": + return relay.tanh(out), dic, param_lst + elif activation == "sigmoid": + return relay.sigmoid(out), dic, param_lst + else: + return out, dic, param_lst + + +def get_conv1d_bias_bn_relu(x_shape=(1, 3, 224), k_shape=(10, 3, 3), dtype="float32"): + conv1d_bias, dic, param_lst = get_conv1d_bias(x_shape, k_shape, dtype=dtype) + beta = relay.const(np.zeros(k_shape[0]).astype(dtype)) + gamma = relay.const(np.ones(k_shape[0]).astype(dtype)) + moving_mean = relay.const(np.zeros(k_shape[0]).astype(dtype)) + moving_var = relay.const(np.ones(k_shape[0]).astype(dtype)) + conv1d_bias_bn, _, _ = relay.nn.batch_norm( + conv1d_bias, + gamma=gamma, + beta=beta, + moving_mean=moving_mean, + moving_var=moving_var, + axis=1, + center=True, + scale=True, + epsilon=1e-5, + ) + return relay.nn.relu(conv1d_bias_bn), dic, param_lst + + def get_conv2d( x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), @@ -134,6 +208,39 @@ def get_conv2d( return out, dic, param_lst +def get_conv2d_transpose( + x_shape=(1, 32, 8, 8), + k_shape=(32, 16, 3, 3), + groups=1, + padding=(0, 0), + strides=(1, 1), + activation=None, + dtype="float32", +): + x = relay.var("x", shape=(x_shape), dtype=dtype) + kernel = relay.var("kernel", shape=(k_shape), dtype=dtype) + out = relay.nn.conv2d_transpose( + x, + kernel, + channels=k_shape[1], + kernel_size=k_shape[2:4], + groups=groups, + padding=padding, + strides=strides, + ) + dic = {"x": x_shape, "kernel": k_shape} + param_lst = ["kernel"] + + if activation == "relu": + return relay.nn.relu(out), dic, param_lst + elif activation == 
"tanh": + return relay.tanh(out), dic, param_lst + elif activation == "sigmoid": + return relay.sigmoid(out), dic, param_lst + else: + return out, dic, param_lst + + def get_conv2d_weights_const( x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), @@ -179,6 +286,25 @@ def get_conv2d_bias( return out, dic, param_lst +def get_conv2d_transpose_bias( + x_shape=(1, 32, 8, 8), k_shape=(32, 16, 3, 3), activation=None, dtype="float32" +): + conv, dic, param_lst = get_conv2d_transpose(x_shape=x_shape, k_shape=k_shape, dtype=dtype) + bias = relay.var("bias", shape=(k_shape[1],), dtype=dtype) + out = relay.nn.bias_add(conv, bias) + dic["bias"] = (k_shape[1],) + param_lst += ["bias"] + + if activation == "relu": + return relay.nn.relu(out), dic, param_lst + elif activation == "tanh": + return relay.tanh(out), dic, param_lst + elif activation == "sigmoid": + return relay.sigmoid(out), dic, param_lst + else: + return out, dic, param_lst + + def get_conv2d_bias_bn_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"): conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, dtype=dtype) beta = relay.const(np.zeros(k_shape[0]).astype(dtype)) @@ -199,6 +325,118 @@ def get_conv2d_bias_bn_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype return relay.nn.relu(conv2d_bias_bn), dic, param_lst +def get_conv3d( + x_shape=(1, 32, 8, 8, 8), + k_shape=(16, 32, 3, 3, 3), + groups=1, + padding=(0, 0, 0), + strides=(1, 1, 1), + dilation=(1, 1, 1), + activation=None, + dtype="float32", +): + x = relay.var("x", shape=(x_shape), dtype=dtype) + kernel = relay.var("kernel", shape=(k_shape), dtype=dtype) + out = relay.nn.conv3d( + x, + kernel, + channels=k_shape[0], + kernel_size=k_shape[2:], + groups=groups, + padding=padding, + strides=strides, + dilation=dilation, + ) + dic = {"x": x_shape, "kernel": k_shape} + param_lst = ["kernel"] + + if activation == "relu": + return relay.nn.relu(out), dic, param_lst + elif activation == "tanh": + return relay.tanh(out), dic, param_lst + elif activation == "sigmoid": + return relay.sigmoid(out), dic, param_lst + else: + return out, dic, param_lst + + +def get_conv3d_transpose( + x_shape=(1, 32, 8, 8, 8), + k_shape=(32, 16, 3, 3, 3), + groups=1, + padding=(0, 0, 0), + strides=(1, 1, 1), + output_padding=(0, 0, 0), + activation=None, + dtype="float32", + data_layout="NCDHW", + kernel_layout="OIDHW", +): + x = relay.var("x", shape=(x_shape), dtype=dtype) + kernel = relay.var("kernel", shape=(k_shape), dtype=dtype) + out = relay.nn.conv3d_transpose( + x, + kernel, + channels=k_shape[1], + kernel_size=k_shape[2:5], + groups=groups, + padding=padding, + strides=strides, + output_padding=output_padding, + data_layout=data_layout, + kernel_layout=kernel_layout, + ) + dic = {"x": x_shape, "kernel": k_shape} + param_lst = ["kernel"] + + if activation == "relu": + return relay.nn.relu(out), dic, param_lst + elif activation == "tanh": + return relay.tanh(out), dic, param_lst + elif activation == "sigmoid": + return relay.sigmoid(out), dic, param_lst + else: + return out, dic, param_lst + + +def get_conv3d_bias( + x_shape=(1, 32, 8, 8, 8), k_shape=(16, 32, 3, 3, 3), activation=None, dtype="float32" +): + conv, dic, param_lst = get_conv3d(x_shape=x_shape, k_shape=k_shape, dtype=dtype) + bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype) + out = relay.nn.bias_add(conv, bias) + dic["bias"] = (k_shape[0],) + param_lst += ["bias"] + + if activation == "relu": + return relay.nn.relu(out), dic, param_lst + elif activation == "tanh": + return relay.tanh(out), dic, 
param_lst + elif activation == "sigmoid": + return relay.sigmoid(out), dic, param_lst + else: + return out, dic, param_lst + + +def get_conv3d_transpose_bias( + x_shape=(1, 32, 8, 8, 8), k_shape=(32, 16, 3, 3, 3), activation=None, dtype="float32" +): + conv, dic, param_lst = get_conv3d_transpose(x_shape=x_shape, k_shape=k_shape, dtype=dtype) + bias = relay.var("bias", shape=(k_shape[1],), dtype=dtype) + out = relay.nn.bias_add(conv, bias) + dic["bias"] = (k_shape[1],) + param_lst += ["bias"] + + if activation == "relu": + return relay.nn.relu(out), dic, param_lst + elif activation == "tanh": + return relay.tanh(out), dic, param_lst + elif activation == "sigmoid": + return relay.sigmoid(out), dic, param_lst + else: + return out, dic, param_lst + + def get_dense(x_shape=(1, 16), k_shape=(32, 16), activation=None, dtype="float32"): x = relay.var("x", shape=(x_shape), dtype=dtype) kernel = relay.var("kernel", shape=(k_shape), dtype=dtype) @@ -249,14 +487,20 @@ def get_graph(): run_and_verify_func(get_graph(), run_module=run_module, dtype=dtype) -def test_unary(run_module): +def test_elementwise(run_module, dtype="float32"): def get_graph(op, x_shape=(1, 8, 3, 3)): - x = relay.var("x", shape=(x_shape), dtype="float32") + x = relay.var("x", shape=(x_shape), dtype=dtype) out = op(x) f = tvm.IRModule.from_expr(out) return f, {"x": x_shape}, [] for op in [ + relay.abs, + relay.exp, + relay.log, + relay.sqrt, + relay.round, + relay.logsumexp, relay.nn.relu, relay.tanh, relay.sigmoid, @@ -264,6 +508,62 @@ def get_graph(op, x_shape=(1, 8, 3, 3)): run_and_verify_func(get_graph(op), run_module=run_module) +def test_clip(run_module, dtype="float32"): + def get_graph(x_shape=(1, 8, 3, 3)): + x = relay.var("x", shape=(x_shape), dtype=dtype) + out = relay.clip(x, a_min=-0.2, a_max=0.4) + f = tvm.IRModule.from_expr(out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph(), run_module=run_module) + + +def test_leaky_relu(run_module, dtype="float32"): + def get_graph(x_shape=(1, 8, 3, 3)): + x = relay.var("x", shape=(x_shape), dtype=dtype) + out = relay.nn.leaky_relu(x, alpha=0.1) + f = tvm.IRModule.from_expr(out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph(), run_module=run_module) + + +def test_softmax(run_module, dtype="float32"): + def get_graph(x_shape, axis): + x = relay.var("x", shape=(x_shape), dtype=dtype) + out = relay.nn.softmax(x, axis=axis) + f = tvm.IRModule.from_expr(out) + return f, {"x": x_shape}, [] + + run_and_verify_func(get_graph((1, 1000), axis=1), run_module=run_module) + run_and_verify_func(get_graph((1, 1000), axis=-1), run_module=run_module) + run_and_verify_func(get_graph((1, 3, 4), axis=-2), run_module=run_module) + run_and_verify_func(get_graph((1, 3, 4), axis=1), run_module=run_module) + + +def test_conv1d(run_module, dtype="float32"): + conv1d, dic, param_lst = get_conv1d(channels=10, dtype=dtype) + conv1d = tvm.IRModule.from_expr(conv1d) + config = conv1d, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + +def test_conv1d_pattern(run_module, dtype="float32"): + x_shape = (1, 3, 224) + k_shape = (10, 3, 3) + activation_lst = [None, "relu", "tanh", "sigmoid"] + for a in activation_lst: + conv1d, dic, param_lst = get_conv1d(x_shape, k_shape, activation=a, dtype=dtype) + conv1d = tvm.IRModule.from_expr(conv1d) + config = conv1d, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + conv1d_bias, dic, param_lst = get_conv1d_bias(x_shape, k_shape, activation=a, dtype=dtype) + 
conv1d_bias = tvm.IRModule.from_expr(conv1d_bias) + config = conv1d_bias, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + def test_conv2d(run_module, dtype="float32"): x_shape = (1, 32, 8, 8) for k_shape, groups in [((16, 32, 3, 3), 1), ((32, 1, 3, 3), 32)]: @@ -314,6 +614,90 @@ def test_conv2d_pattern(run_module, dtype="float32"): run_and_verify_func(config, run_module=run_module, dtype=dtype) +def test_conv2d_transpose(run_module, dtype="float32"): + for padding in [(0, 0), (1, 1)]: + for strides in [(1, 1), (2, 2)]: + conv2d_transpose, dic, param_lst = get_conv2d_transpose( + padding=padding, strides=strides, dtype=dtype + ) + conv2d_transpose = tvm.IRModule.from_expr(conv2d_transpose) + config = conv2d_transpose, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + +def test_conv2d_transpose_pattern(run_module, dtype="float32"): + activation_lst = [None, "relu", "tanh", "sigmoid"] + for a in activation_lst: + conv2d, dic, param_lst = get_conv2d_transpose(activation=a, dtype=dtype) + conv2d = tvm.IRModule.from_expr(conv2d) + config = conv2d, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + conv2d_bias, dic, param_lst = get_conv2d_transpose_bias(activation=a, dtype=dtype) + conv2d_bias = tvm.IRModule.from_expr(conv2d_bias) + config = conv2d_bias, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + +def test_conv3d(run_module, dtype="float32"): + conv3d, dic, param_lst = get_conv3d(dtype=dtype) + conv3d = tvm.IRModule.from_expr(conv3d) + config = conv3d, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + conv3d, dic, param_lst = get_conv3d(padding=(0, 0, 0, 1, 1, 1), dtype=dtype) + conv3d = tvm.IRModule.from_expr(conv3d) + config = conv3d, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + +def test_conv3d_pattern(run_module, dtype="float32"): + activation_lst = [None, "relu", "tanh", "sigmoid"] + for a in activation_lst: + conv3d, dic, param_lst = get_conv3d(activation=a, dtype=dtype) + conv3d = tvm.IRModule.from_expr(conv3d) + config = conv3d, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + conv3d_bias, dic, param_lst = get_conv3d_bias(activation=a, dtype=dtype) + conv3d_bias = tvm.IRModule.from_expr(conv3d_bias) + config = conv3d_bias, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + +def test_conv3d_transpose(run_module, dtype="float32"): + conv3d_transpose, dic, param_lst = get_conv3d_transpose(dtype=dtype) + conv3d_transpose = tvm.IRModule.from_expr(conv3d_transpose) + config = conv3d_transpose, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + conv3d_transpose, dic, param_lst = get_conv3d_transpose(strides=(2, 2, 2), dtype=dtype) + conv3d_transpose = tvm.IRModule.from_expr(conv3d_transpose) + config = conv3d_transpose, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + conv3d_transpose, dic, param_lst = get_conv3d_transpose( + strides=(2, 2, 2), output_padding=(1, 1, 1), dtype=dtype + ) + conv3d_transpose = tvm.IRModule.from_expr(conv3d_transpose) + config = conv3d_transpose, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + +def test_conv3d_transpose_pattern(run_module, dtype="float32"): + activation_lst = [None, "relu", "tanh", "sigmoid"] + for a in activation_lst: + conv3d, dic, param_lst = 
get_conv3d_transpose(activation=a, dtype=dtype) + conv3d = tvm.IRModule.from_expr(conv3d) + config = conv3d, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + conv3d_bias, dic, param_lst = get_conv3d_transpose_bias(activation=a, dtype=dtype) + conv3d_bias = tvm.IRModule.from_expr(conv3d_bias) + config = conv3d_bias, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + def test_dense(run_module, dtype="float32"): x_shape = (1, 16) k_shape = (32, 16) @@ -344,6 +728,111 @@ def test_dense_pattern(run_module, dtype="float32"): run_and_verify_func(config, run_module=run_module, dtype=dtype) +def test_pool2d(run_module, dtype="float32"): + def get_graph( + op, + x_shape=(1, 3, 32, 32), + pool_size=(2, 2), + strides=(2, 2), + padding=(0, 0), + ceil_mode=False, + count_include_pad=None, + ): + x = relay.var("x", shape=(x_shape), dtype=dtype) + if count_include_pad is not None: + out = op( + x, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ) + else: + out = op( + x, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + ) + out = tvm.IRModule.from_expr(out) + return out, {"x": x_shape}, [] + + for pool_size in [(2, 2), (3, 3)]: + for strides in [(1, 1), (2, 2)]: + for padding in [(0, 0), (1, 1), (0, 0, 1, 1)]: + for ceil_mode in [False]: + # Skip "the padding size is larger than or equal to the filter size for exclusive-counting pooling" + if pool_size == (2, 2) and padding == (0, 0, 1, 1): + continue + for count_include_pad in [False, True]: + # Skip "inclusive-counted blended or average pooling is not supported in combination with asymmetric padding" + if count_include_pad and (padding == (0, 0, 1, 1) or strides == (2, 2)): + continue + run_and_verify_func( + get_graph( + relay.nn.avg_pool2d, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ), + run_module=run_module, + ) + run_and_verify_func( + get_graph( + relay.nn.max_pool2d, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + ), + run_module=run_module, + ) + + +def test_pool3d(run_module, dtype="float32"): + def get_graph( + op, + x_shape=(1, 3, 8, 32, 32), + pool_size=(2, 2, 2), + strides=(2, 2, 2), + padding=(0, 0, 0), + ceil_mode=False, + count_include_pad=None, + dtype="float32", + ): + x = relay.var("x", shape=(x_shape), dtype=dtype) + if count_include_pad is not None: + out = op( + x, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ) + else: + out = op( + x, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + ) + out = tvm.IRModule.from_expr(out) + return out, {"x": x_shape}, [] + + run_and_verify_func(get_graph(relay.nn.avg_pool3d), run_module=run_module) + run_and_verify_func(get_graph(relay.nn.max_pool3d), run_module=run_module) + run_and_verify_func( + get_graph(relay.nn.max_pool3d, padding=(0, 0, 0, 1, 1, 1)), run_module=run_module + ) + run_and_verify_func(get_graph(relay.nn.max_pool3d, strides=(1, 1, 1)), run_module=run_module) + + if __name__ == "__main__": import sys diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 736ece265bde..080c6c803961 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -925,7 
+925,13 @@ def test_dnnl_fuse(): conv2d_bias_pat, conv2d_relu_pat, conv2d_sigmoid_pat, - ) = (dnnl_patterns[0], dnnl_patterns[4], dnnl_patterns[6], dnnl_patterns[8], dnnl_patterns[12]) + ) = ( + dnnl_patterns[1], + dnnl_patterns[13], + dnnl_patterns[19], + dnnl_patterns[25], + dnnl_patterns[37], + ) def get_blocks( prefix, @@ -1033,8 +1039,8 @@ def test_partition(): def test_partition_mobilenet(): mod, params = relay.testing.mobilenet.get_workload() mod = get_partitoned_mod(mod, params, dnnl_patterns) - # 27 fused conv + bn + relu and one dense - assert len(mod.functions) - 1 == 28 # -1 for main + # 27 fused conv + bn + relu, one dense and one softmax + assert len(mod.functions) - 1 == 29 # -1 for main def test_exec(mod, params, ref_mod, ref_params, out_shape): ishape = (1, 3, 224, 224)
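
The index changes in `test_dnnl_fuse` follow directly from the new enumeration order in `pattern_table()`: for each (bias, eltwise) combination the table now emits five conv-style patterns (conv1d, conv2d, conv3d, conv2d_transpose, conv3d_transpose) before the dense pattern. The sketch below reproduces that ordering, assuming the existing `elt_list` order of relu, tanh, sigmoid, None is unchanged; it recovers the updated indices 1, 13, 19, 25 and 37 used above.

```python
# Reproduce the pattern-name ordering of the updated pattern_table() to show
# why the named conv2d patterns moved to indices 1, 13, 19, 25 and 37.
conv_ops = ["conv1d", "conv2d", "conv3d", "conv2d_transpose", "conv3d_transpose"]
elt_list = ["relu", "tanh", "sigmoid", None]

names = []
for with_bias in [True, False]:
    for elt in elt_list:
        if not with_bias and not elt:
            break
        for op in conv_ops + ["dense"]:
            name = "dnnl." + op
            name += "_bias" if with_bias else ""
            name += "_" + elt if elt else ""
            names.append(name)

assert names[1] == "dnnl.conv2d_bias_relu"
assert names[13] == "dnnl.conv2d_bias_sigmoid"
assert names[19] == "dnnl.conv2d_bias"
assert names[25] == "dnnl.conv2d_relu"
assert names[37] == "dnnl.conv2d_sigmoid"
```
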