diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index 6581f10a2f56..c251b66bfbc7 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -39,13 +39,14 @@ from tvm.relay import transform from tvm.relay.expr import GlobalVar from tvm.relay.expr_functor import ExprMutator, ExprVisitor +from tvm.relay.expr import const from tvm.relay.analysis import analysis as _analysis from tvm.relay import expr as _expr from ... import _ffi_api -from ...dataflow_pattern import wildcard, is_op, is_expr, rewrite, DFPatternCallback +from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr, rewrite, DFPatternCallback from .register import register_pattern_table @@ -56,8 +57,8 @@ def _register_external_op_helper(op_name, supported=True): """The helper function to indicate that a given operator can be supported by DNNL. - Paramters - --------- + Parameters + ---------- op_name : Str The name of operator that will be registered. @@ -69,6 +70,10 @@ def _register_external_op_helper(op_name, supported=True): @tvm.ir.register_op_attr(op_name, "target.dnnl") def _func_wrapper(expr): + args = expr.args + if any([x.checked_type.dtype == "int64" for x in args]): + logger.info("DNNL does not support int64.") + return False return supported return _func_wrapper @@ -90,6 +95,7 @@ def _func_wrapper(expr): _register_external_op_helper("exp") _register_external_op_helper("log") _register_external_op_helper("sqrt") +_register_external_op_helper("round") _register_external_op_helper("nn.relu") _register_external_op_helper("nn.leaky_relu") _register_external_op_helper("tanh") @@ -199,6 +205,70 @@ def make_dnnl_pattern(op_name, with_bias, with_eltwise): return dnnl_pattern +def make_qnn_conv2d_pattern(): + """Make qnn.conv2d based pattern supported by DNNL + + Returns + ------- + pattern : Tuple(pattern_name, CallPattern) + Created pattern name, along with its CallPattern. + """ + data = wildcard() + weight = is_constant() + bias = is_constant() + o_scl = is_constant() + dst_zp = is_constant() + act_scl = is_constant() + sum_scl = is_constant() + sum_src = wildcard() + + zero_zp = is_expr(const(0, dtype="int32")) + + pat = is_op("qnn.conv2d")(data, weight, zero_zp, zero_zp, is_constant(), is_constant()) + pat = is_op("cast")(pat) + pat = is_op("add")(pat, bias) | pat # optional bias + pat = is_op("multiply")(pat, o_scl) + pat = is_op("clip")(pat) # TBD, not only clip + pat = is_op("multiply")(pat, act_scl) | pat # optional multiply. Ex: act_scl == 1 + pat = is_op("add")(pat, sum_scl * is_op("cast")(sum_src)) | pat # optional sum + pat = is_op("add")(pat, dst_zp) | pat # optional dst_zp, can be dst_zp == 0 + pat = is_op("cast")(pat) + + return "dnnl.qnn.conv2d", pat + + +def make_qnn_dense_pattern(): + """Make qnn.dense based pattern supported by DNNL + + Returns + ------- + pattern : Tuple(pattern_name, CallPattern) + Created pattern name, along with its CallPattern. + """ + data = wildcard() + weight = is_constant() + bias = is_constant() + o_scl = is_constant() + dst_zp = is_constant() + act_scl = is_constant() + sum_scl = is_constant() + sum_src = wildcard() + + zero_zp = is_expr(const(0, dtype="int32")) + + pat = is_op("qnn.dense")(data, weight, zero_zp, zero_zp, is_constant(), is_constant()) + pat = is_op("cast")(pat) + pat = is_op("add")(pat, bias) | pat # optional bias + pat = is_op("multiply")(pat, o_scl) + pat = is_op("clip")(pat) # TBD, not only clip + pat = is_op("multiply")(pat, act_scl) | pat # optional multiply. 
Ex: act_scl == 1
+    pat = is_op("add")(pat, sum_scl * is_op("cast")(sum_src)) | pat  # optional sum
+    pat = is_op("add")(pat, dst_zp) | pat  # optional dst_zp, can be dst_zp == 0
+    pat = is_op("cast")(pat)
+
+    return "dnnl.qnn.dense", pat
+
+
 @register_pattern_table("dnnl")
 def pattern_table():
     """Create dnnl patterns.
@@ -208,8 +278,11 @@ def pattern_table():
     dnnl_patterns : List[dnnl_pattern]
         Created patterns.
     """
+    dnnl_patterns = list()
+    dnnl_patterns.append(make_qnn_conv2d_pattern())
+    dnnl_patterns.append(make_qnn_dense_pattern())
+
     elt_list = ["nn.relu", "tanh", "sigmoid", "gelu", None]
-    dnnl_patterns = []
     for with_bias in [True, False]:
         for elt in elt_list:
             if not with_bias and not elt:
@@ -707,3 +780,201 @@ def rewrite_dense_bias_gelu_reshape_last(mod):
         [DenseReshapeBiasGeluRewrite(), DenseReshapeBiasGeluRewrite(has_gelu=False)], mod["main"]
     )
     return mod
+
+
+class LegalizeQnnOpForDnnl(DFPatternCallback):
+    """Legalize QNN-based patterns to match DNNL
+
+    original pattern:
+      OP = qnn.dense | qnn.conv2d
+      %1 = OP(SRC, WGH) - OP(src_zp, WGH)                        // qnn.conv2d
+      %2 = %1 + orig_bias                                        // optional bias
+      %3 = (%2 - rq_in_zp) * rq_in_scl / rq_out_scl + rq_out_zp  // qnn.requantize
+      %4 = act(%3)                                               // activation == clip
+      %5 = ((%4 - sum_lhs_zp) * sum_lhs_scl + (SRC2 - sum_rhs_zp) * sum_rhs_scl)  // qnn.add
+           / sum_out_scl + sum_out_zp
+
+    transform to DNNL compatible:
+      %1 = OP(SRC, WGH)
+      %2 = cast(%1, dtype="float")
+      %3 = (%2 + bias) * o_scl
+      %4 = act(%3) * act_scl
+      %5 = %4 + SRC2 * sum_scl
+      %6 = %5 + dst_zp
+      %7 = cast(%6, dtype=<dtype of the original pattern output>)
+
+    where:
+      o_scl = rq_in_scl / rq_out_scl
+      act_scl = sum_lhs_scl / sum_out_scl
+      sum_scl = sum_rhs_scl / sum_out_scl
+      bias = orig_bias - OP(src_zp, WGH) - rq_in_zp + rq_out_zp * rq_out_scl / rq_in_scl
+      dst_zp = sum_out_zp - sum_lhs_zp * sum_lhs_scl / sum_out_scl -
+               sum_rhs_zp * sum_rhs_scl / sum_out_scl
+    """
+
+    def __init__(self):
+        super(LegalizeQnnOpForDnnl, self).__init__()
+        self.src = wildcard()
+        self.wgh = wildcard()
+        self.bias = wildcard()
+        self.sum_src = wildcard()
+
+        self.src_scl = is_constant()
+        self.src_zp = is_constant()
+        self.wgh_scl = is_constant()
+        self.wgh_zp = is_expr(const(0))
+
+        self.rq_in_scl = is_constant()
+        self.rq_in_zp = is_constant()
+        self.rq_out_scl = is_constant()
+        self.rq_out_zp = is_constant()
+
+        self.sum_lhs_scl = is_constant()
+        self.sum_lhs_zp = is_constant()
+        self.sum_rhs_scl = is_constant()
+        self.sum_rhs_zp = is_constant()
+        self.sum_out_scl = is_constant()
+        self.sum_out_zp = is_constant()
+
+        self.root = (is_op("qnn.conv2d") | is_op("qnn.dense"))(
+            self.src, self.wgh, self.src_zp, self.wgh_zp, self.src_scl, self.wgh_scl
+        )
+        pat = is_op("add")(self.root, self.bias) | self.root  # optional bias
+        pat = is_op("qnn.requantize")(
+            pat, self.rq_in_scl, self.rq_in_zp, self.rq_out_scl, self.rq_out_zp
+        )
+        pat = is_op("clip")(pat)
+        cast = is_op("cast")(pat)
+        pat = is_op("qnn.add")(
+            cast,
+            self.sum_src,
+            self.sum_lhs_scl,
+            self.sum_lhs_zp,
+            self.sum_rhs_scl,
+            self.sum_rhs_zp,
+            self.sum_out_scl,
+            self.sum_out_zp,
+        )
+        pat = is_op("clip")(pat)
+        self.pattern = pat | cast
+
+    def callback(self, pre, post, node_map):
+        root = node_map[self.root][0]
+        src = node_map[self.src][0]
+        wgh = node_map[self.wgh][0]
+        bias = node_map.get(self.bias, default=[relay.const(0, dtype="int32")])[0]
+        src_zp = node_map[self.src_zp][0]
+        rq_in_scl = node_map[self.rq_in_scl][0]
+        rq_in_zp = node_map[self.rq_in_zp][0]
+        rq_out_scl = node_map[self.rq_out_scl][0]
+        rq_out_zp = node_map[self.rq_out_zp][0]
+
+        final_dtype = node_map[self.pattern][0].checked_type.dtype
+
+        if root.op == relay.op.get("qnn.conv2d"):
+            dst_layout = root.attrs.out_layout
+            dst_layout = root.attrs.data_layout if dst_layout == "" else dst_layout
+            wgh_layout = root.attrs.kernel_layout
+        else:
+            # qnn.dense has no layout attributes. Assume that it is plain
+            dst_layout = "NC"
+            wgh_layout = "OI"
+
+        # TODO(@apeskov): dst_layout may be blocked
+        bias_rank = len(dst_layout) - dst_layout.index("C")
+
+        sum_src = node_map[self.sum_src][0] if self.sum_src in node_map else None
+        # Default values if qnn.add is not present
+        sum_lhs_scl = node_map[self.sum_lhs_scl][0] if sum_src else relay.const(1, dtype="float32")
+        sum_lhs_zp = node_map[self.sum_lhs_zp][0] if sum_src else relay.const(0, dtype="int32")
+        sum_rhs_scl = node_map[self.sum_rhs_scl][0] if sum_src else relay.const(0, dtype="float32")
+        sum_rhs_zp = node_map[self.sum_rhs_zp][0] if sum_src else relay.const(0, dtype="int32")
+        sum_out_scl = node_map[self.sum_out_scl][0] if sum_src else relay.const(1, dtype="float32")
+        sum_out_zp = node_map[self.sum_out_zp][0] if sum_src else relay.const(0, dtype="int32")
+
+        def cast_fp(op):
+            return relay.op.cast(op, dtype="float32")
+
+        # Recalculate the combined scale and zero-point factors
+        o_scl = rq_in_scl / rq_out_scl
+        act_scl = sum_lhs_scl / sum_out_scl
+        sum_scl = sum_rhs_scl / sum_out_scl
+        dst_zp = (
+            cast_fp(sum_out_zp)
+            - cast_fp(sum_lhs_zp) * sum_lhs_scl / sum_out_scl
+            - cast_fp(sum_rhs_zp) * sum_rhs_scl / sum_out_scl
+        )
+        bias = self.squeeze_bias(bias, dst_layout)
+        bias = (
+            cast_fp(bias)
+            - cast_fp(self.fake_op(src_zp, wgh, wgh_layout))
+            - cast_fp(rq_in_zp)
+            + cast_fp(rq_out_zp) * rq_out_scl / rq_in_scl
+        )
+        bias = self.broadcast_to_rank(bias, bias_rank)
+
+        zero_zp = relay.const(0, dtype="int32")
+        one_scl = relay.const(1.0, dtype="float32")
+
+        # Construct the new graph with the strict post-op ordering
+        gr = tvm.relay.Call(
+            root.op,
+            [src, wgh, zero_zp, zero_zp, one_scl, one_scl],
+            root.attrs,
+            root.type_args,
+            root.span,
+        )
+        gr = relay.op.cast(gr, dtype="float32")
+        gr = gr + bias
+        gr = gr * o_scl
+        gr = relay.op.clip(gr, 0, 255) * act_scl
+        gr = gr + sum_scl * cast_fp(sum_src) if sum_src else gr
+        gr = gr + dst_zp
+        gr = relay.op.cast(gr, dtype=final_dtype)
+        return gr
+
+    @staticmethod
+    def fake_op(zp, wgh, layout):
+        """Emulate the OP(src_zp, WGH) term by reducing the weights over all non-output axes"""
+        # Conv: reduce kernel {OC, IC, KH, KW} -> {OC}; still correct for grouped convolutions
+        # Dense: reduce kernel {OC, IC} -> {OC}
+        wgh_int = relay.op.cast(wgh, dtype="int32")
+        reduced_kernel = relay.op.sum(
+            wgh_int, axis=[layout.index("O")], keepdims=False, exclude=True
+        )
+        return zp * reduced_kernel
+
+    @staticmethod
+    def squeeze_bias(bias, layout):
+        shape = transform.InferTypeLocal(bias).concrete_shape
+        c_position = layout.index("C") - len(layout) + len(shape)
+        squeeze_idxs = [i for i in range(len(shape)) if i != c_position]
+        return relay.op.squeeze(bias, squeeze_idxs)
+
+    @staticmethod
+    def broadcast_to_rank(op, rank):
+        """Only scalar and 1D tensors are supported"""
+        shape = transform.InferTypeLocal(op).concrete_shape
+        if len(shape) == 0:
+            return op
+        if len(shape) == 1:
+            return relay.op.expand_dims(op, 1, rank - 1)
+        raise ValueError("Unexpected bias rank to broadcast. Only 0 and 1 are supported.")
+
+
+def legalize_qnn_for_dnnl(mod):
+    """Transform qnn primitives to DNNL compatible form.
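+
+    Illustration of the factor algebra with made-up scalar values (names as in
+    LegalizeQnnOpForDnnl; the rewrite is algebraically exact in float, integer
+    execution only adds rounding):
+      rq_in_scl=0.02, rq_in_zp=0, rq_out_scl=0.2, rq_out_zp=30,
+      sum_lhs_scl=0.2, sum_lhs_zp=10, sum_rhs_scl=0.3, sum_rhs_zp=15,
+      sum_out_scl=0.2, sum_out_zp=5,
+      OP(SRC, WGH)=123, OP(src_zp, WGH)=17, orig_bias=-50, SRC2=40
+
+      original:  clip((123 - 17 - 50) * 0.02 / 0.2 + 30, 0, 255) = 35.6
+                 ((35.6 - 10) * 0.2 + (40 - 15) * 0.3) / 0.2 + 5 = 68.1
+      rewritten: o_scl = 0.1, act_scl = 1.0, sum_scl = 1.5, bias = 233.0, dst_zp = -27.5
+                 clip((123 + 233) * 0.1, 0, 255) * 1.0 + 40 * 1.5 - 27.5 = 68.1
+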
Eliminate source zero point and apply + strict sequence of post ops.""" + mod["main"] = rewrite(LegalizeQnnOpForDnnl(), mod["main"]) + + seq = tvm.transform.Sequential( + [ + transform.InferType(), + # transform.SimplifyInference(), # TODO: this pass decompose nn.layer_norm + # transform.FoldScaleAxis(), # TODO: fail inside TVM in case of grouped convolutions. + transform.FoldConstant(), + ] + ) + with tvm.transform.PassContext(opt_level=3): + mod = seq(mod) + return mod diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index 927cd12ae0fb..f17cdafa76a5 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -35,6 +35,7 @@ #include #include "../../utils.h" +#include "comp_op_matcher.h" #ifdef USE_JSON_RUNTIME #include "../../../../runtime/contrib/json/json_node.h" @@ -436,6 +437,30 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { #else // DNNL JSON runtime +/*! + * \brief Replace var expr which bind with args of call node + * + * \param args vector of expression (contains vars or constant nodes) + * \param cn call node which describe mapping of internal body vars with args + * \return updated vector of expressions + */ +static tvm::Array BindToCallNodeArgs(const std::vector& args, const CallNode* cn) { + tvm::Array res; + for (const auto& arg : args) { + if (arg->IsInstance()) { + res.push_back(arg); + } else { + auto body_params = cn->op.as()->params; + auto found = std::find(body_params.begin(), body_params.end(), arg); + ICHECK(found != body_params.end()); + auto idx = std::distance(body_params.begin(), found); + res.push_back(cn->args[idx]); + } + } + return res; +} + +/*! \brief Serializer to DNNL JSON runtime module */ class DNNLJSONSerializer : public backend::contrib::JSONSerializer { using JSONGraphNode = tvm::runtime::json::JSONGraphNode; using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; @@ -475,14 +500,19 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer { } public: - DNNLJSONSerializer(const std::string& symbol, const Expr& expr) : JSONSerializer(symbol, expr) {} + DNNLJSONSerializer(const std::string& symbol, const Expr& expr) + : JSONSerializer("dnnl_" + symbol, expr) {} std::vector VisitExpr_(const CallNode* cn) override { Expr expr = GetRef(cn); std::string name; + tvm::Array args; + std::unordered_map extra_attrs; + const CallNode* call = cn; if (const auto* op_node = cn->op.as()) { name = op_node->name; + args = cn->args; } else if (const auto* fn = cn->op.as()) { auto comp = fn->GetAttr(attr::kComposite); ICHECK(comp.defined()) << "DNNL JSON runtime only supports composite functions."; @@ -511,15 +541,24 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer { } else if (name.find("dnnl.dense") != std::string::npos) { call = GetRootCall(fn->body.as(), 10, "nn.dense"); ICHECK(call->op.as()) << "Not op node"; + } else if (name.find("dnnl.qnn.conv2d") != std::string::npos || + name.find("dnnl.qnn.dense") != std::string::npos) { + std::vector args_loc; + call = ParseComposite(*fn, &extra_attrs, &args_loc); + args = BindToCallNodeArgs(args_loc, cn); } else { LOG(FATAL) << "Unrecognized DNNL pattern: " << name; } + + if (args.empty()) { + args = cn->args; + } } else { LOG(FATAL) << "DNNL JSON runtime does not support calls to " << cn->op->GetTypeKey(); } std::vector inputs; - for (const auto& arg : cn->args) { + for (const auto& arg : args) { auto res = VisitExpr(arg); inputs.insert(inputs.end(), res.begin(), 
res.end()); } @@ -527,6 +566,8 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer { "kernel", /* op_type_ */ inputs, 1 /* num_outputs_ */); SetCallNodeAttribute(node, call); + for (const auto& kvp : extra_attrs) node->SetAttr(kvp.first, kvp.second); + return AddNode(node, GetRef(cn)); } }; @@ -558,6 +599,61 @@ runtime::Module DNNLCompiler(const ObjectRef& ref) { TVM_REGISTER_GLOBAL("relay.ext.dnnl").set_body_typed(DNNLCompiler); +/*! + * \brief Constant Updater for DNNL JSON runtime + * + * Not all originally existing ConstantNode should be passed to JSON runtime. + * Some of them may be skipped or change ordering. So we have to apply the same traversing through + * the graph as DNNLJSONSerializer. + */ +struct DNNLConstantUpdater : public ConstantUpdater { + public: + DNNLConstantUpdater(const std::string& symbol, + std::unordered_map* params) + : ConstantUpdater("dnnl_" + symbol, params) {} + using ConstantUpdater::VisitExpr_; + + void VisitExpr_(const CallNode* cn) final { + this->VisitSpan(cn->span); + + if (const auto* fn = cn->op.as()) { + std::vector args_loc; + std::unordered_map attrs; + auto root_cn = ParseComposite(*fn, &attrs, &args_loc); + + auto args = root_cn ? BindToCallNodeArgs(args_loc, cn) : cn->args; + + // Customized visit order of args + for (const auto& arg : args) { + this->VisitExpr(arg); + } + } else { + // Original visit order of args + for (auto arg : cn->args) { + this->VisitExpr(arg); + } + } + } +}; + +/*! + * \brief The external compiler/codegen tool. It takes a Relay expression/module and + * produce collection of required constant NDArrays. + */ +Map DNNLConstantUpdaterFunc(Expr expr, std::string symbol) { + // Visit all suitable constant nodes + std::unordered_map res; + DNNLConstantUpdater const_updater(symbol, &res); + const_updater(expr); + + // Convert to tvm::Map + Map ret; + for (const auto& kvp : res) ret.Set(kvp.first, kvp.second); + return ret; +} + +TVM_REGISTER_GLOBAL("relay.ext.dnnl.constant_updater").set_body_typed(DNNLConstantUpdaterFunc); + } // namespace contrib } // namespace relay } // namespace tvm diff --git a/src/relay/backend/contrib/dnnl/comp_op_matcher.h b/src/relay/backend/contrib/dnnl/comp_op_matcher.h new file mode 100644 index 000000000000..364cc6e377ca --- /dev/null +++ b/src/relay/backend/contrib/dnnl/comp_op_matcher.h @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/contrib/dnnl/comp_op_matcher.h + * \brief Implement matcher based function to parse complex composite nodes. + */ + +#ifndef TVM_RELAY_BACKEND_CONTRIB_DNNL_COMP_OP_MATCHER_H_ +#define TVM_RELAY_BACKEND_CONTRIB_DNNL_COMP_OP_MATCHER_H_ + +#include + +#include +#include +#include + +#include "../../../ir/dataflow_matcher_impl.h" + +/*! 
+ * \brief Converter value to dmlc attr acceptable format + * + * \tparam T type of value (auto deduction) + * \param val value to convert + * \return resulting dmlc object + */ +template ::value, bool> = true> +dmlc::any dmlc_attr(const T& val) { + std::vector attr; + attr.emplace_back(std::vector{std::to_string(val)}); + return dmlc::any{attr}; +} + +template ::value, bool> = true> +dmlc::any dmlc_attr(const T& val) { + std::vector attr; + attr.emplace_back(std::vector{val}); + return dmlc::any{attr}; +} + +template >::value, bool> = true> +dmlc::any dmlc_attr(const T& val) { + std::vector attr; + attr.emplace_back(val); + return dmlc::any{attr}; +} + +/*! \brief Constructor of const scalar expression with defined type */ +tvm::relay::Expr constant(float val) { + auto value = tvm::runtime::NDArray::Empty({}, tvm::DataType::Float(32), {kDLCPU, 0}); + value.CopyFromBytes(&val, sizeof(val)); + auto res = tvm::relay::Constant(value); + tvm::relay::transform::InferTypeLocal(res); + return res; +} + +/*! + * \brief Simple helper to accumulate composite function arguments and corresponding attributes + * with indexes of them. + */ +class ArgPacker { + public: + ArgPacker(std::unordered_map* attrs, std::vector* args) + : attrs_(attrs), args_(args) {} + + int Put(const tvm::relay::Expr& arg, std::string tag_name = "") { + if (!arg.defined()) return -1; + int idx = args_->size(); + args_->push_back(arg); + if (!tag_name.empty()) { + attrs_->operator[](tag_name) = dmlc_attr(idx); + } + return idx; + } + + private: + std::unordered_map* attrs_; + std::vector* args_; +}; + +const tvm::relay::CallNode* ParseQnnConvComp(const tvm::relay::FunctionNode& comp_fn, + std::unordered_map* ext_attrs, + std::vector* args) { + using namespace tvm::relay; + + // Pattern + auto src = IsWildcard(); + auto wgh = IsWildcard(); + auto sum_src = IsWildcard(); + auto bias = IsConstant(); + + auto o_scl = IsConstant(); + auto act_scl = IsConstant(); + auto sum_scl = IsConstant(); + auto dst_zp = IsConstant(); + + DFPattern cnv; + DFPattern pat; + + cnv = IsOp("qnn.conv2d")({src, wgh, IsConstant(), IsConstant(), IsConstant(), IsConstant()}); + pat = IsOp("cast")({cnv}); + pat = IsOp("add")({pat, bias}) || pat; + pat = IsOp("multiply")({pat, o_scl}); + pat = IsOp("clip")({pat}); + pat = IsOp("multiply")({pat, act_scl}) || pat; + pat = IsOp("add")({pat, sum_scl * IsOp("cast")({sum_src})}) || pat; + pat = IsOp("add")({pat, dst_zp}) || pat; + pat = IsOp("cast")({pat}); + + // Check pattern match + auto indexed_body = CreateIndexedGraph(comp_fn.body); + DFPatternMatcher matcher(indexed_body.get()); + auto res = matcher.Match(pat, comp_fn.body); + ICHECK(res) << "Mismatch of DNNL partitioner and codegen logic"; + + // Handle arguments in deterministic order + auto map = matcher.GetMemo(); + auto find = [&map](const DFPattern& pat) -> tvm::relay::Expr { + if (map.count(pat)) return map.at(pat)[0]; + return {}; + }; + + ArgPacker arg_holder(ext_attrs, args); + arg_holder.Put(find(src)); + arg_holder.Put(find(wgh)); + arg_holder.Put(find(bias), "bias_idx"); + arg_holder.Put(find(sum_src), "sum_idx"); + arg_holder.Put(find(o_scl), "o_scl_idx"); + arg_holder.Put(find(act_scl), "act_scl_idx"); + arg_holder.Put(find(sum_scl), "sum_scl_idx"); + arg_holder.Put(find(dst_zp), "dst_zp_idx"); + + // Activation. Default clip to simulate relu via uint8 cast + std::vector clip_attr{"clip"}; + auto act_scl_val = map.count(act_scl) ? 
find(act_scl) : constant(1.0); + clip_attr.push_back(std::to_string(arg_holder.Put(act_scl_val))); // act_scale + clip_attr.push_back(std::to_string(arg_holder.Put(constant(0.0)))); // alpha + clip_attr.push_back(std::to_string(arg_holder.Put(constant(255.0)))); // beta + (*ext_attrs)["activation"] = dmlc_attr(clip_attr); + + return map.at(cnv)[0].as(); +} + +const tvm::relay::CallNode* ParseQnnDenseComp(const tvm::relay::FunctionNode& comp_fn, + std::unordered_map* ext_attrs, + std::vector* args) { + using namespace tvm::relay; + + // Pattern + auto src = IsWildcard(); + auto wgh = IsWildcard(); + auto sum_src = IsWildcard(); + auto bias = IsConstant(); + + auto o_scl = IsConstant(); + auto act_scl = IsConstant(); + auto sum_scl = IsConstant(); + auto dst_zp = IsConstant(); + + DFPattern dns, act, pat; + + dns = IsOp("qnn.dense")({src, wgh, IsConstant(), IsConstant(), IsConstant(), IsConstant()}); + pat = IsOp("cast")({dns}); + pat = IsOp("add")({pat, bias}) || pat; + pat = IsOp("multiply")({pat, o_scl}); + pat = IsOp("clip")({pat}); + pat = IsOp("multiply")({pat, act_scl}) || pat; + pat = IsOp("add")({pat, sum_scl * IsOp("cast")({sum_src})}) || pat; + pat = IsOp("add")({pat, dst_zp}) || pat; + pat = IsOp("cast")({pat}); + + // Check pattern match + auto indexed_body = CreateIndexedGraph(comp_fn.body); + DFPatternMatcher matcher(indexed_body.get()); + auto res = matcher.Match(pat, comp_fn.body); + ICHECK(res) << "Mismatch of DNNL partitioner and codegen logic"; + + // Handle arguments in deterministic order + auto memo = matcher.GetMemo(); + auto find = [&memo](const DFPattern& pat) -> tvm::relay::Expr { + if (memo.count(pat)) return memo.at(pat)[0]; + return {}; + }; + + ArgPacker arg_holder(ext_attrs, args); + arg_holder.Put(find(src)); + arg_holder.Put(find(wgh)); + arg_holder.Put(find(bias), "bias_idx"); + arg_holder.Put(find(sum_src), "sum_idx"); + arg_holder.Put(find(o_scl), "o_scl_idx"); + arg_holder.Put(find(act_scl), "act_scl_idx"); + arg_holder.Put(find(sum_scl), "sum_scl_idx"); + arg_holder.Put(find(dst_zp), "dst_zp_idx"); + + // Activation. Default clip to simulate relu via uint8 cast + std::vector clip_attr{"clip"}; + auto act_scl_val = memo.count(act_scl) ? find(act_scl) : constant(1.0); + clip_attr.push_back(std::to_string(arg_holder.Put(act_scl_val))); // act_scale + clip_attr.push_back(std::to_string(arg_holder.Put(constant(0.0)))); // alpha + clip_attr.push_back(std::to_string(arg_holder.Put(constant(255.0)))); // beta + (*ext_attrs)["activation"] = dmlc_attr(clip_attr); + + return memo.at(dns)[0].as(); +} + +/*! 
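+ * The parsers above pack matched arguments in a deterministic order and record
+ * the position of every optional input as a node attribute ("bias_idx",
+ * "sum_idx", "o_scl_idx", "act_scl_idx", "sum_scl_idx", "dst_zp_idx"); the
+ * "activation" attribute holds {"clip", <act_scl_pos>, <alpha_pos>, <beta_pos>}.
+ * A hypothetical fused qnn.conv2d with bias and sum might, for example, carry
+ * bias_idx="2", sum_idx="3", activation={"clip", "7", "8", "9"}; ParseAttrs()
+ * in dnnl_json_runtime.cc resolves these indices back into DNNL output scales
+ * and post-ops (a missing attribute defaults to index -1, i.e. no such input).
+ *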
+ * Parse composite function and return real args, additional attributes and root call node + * @param comp_fn composite function to parse + * @param ext_attrs attr collection with additional attributes + * @param args real arguments of node + * @return root call node + */ +const tvm::relay::CallNode* ParseComposite(const tvm::relay::FunctionNode& comp_fn, + std::unordered_map* ext_attrs, + std::vector* args) { + auto comp = comp_fn.GetAttr(tvm::relay::attr::kComposite); + ICHECK(comp.defined()) << "DNNL JSON runtime only supports composite functions."; + auto name = comp.value(); + + const tvm::relay::CallNode* res = nullptr; + if (name == "dnnl.qnn.conv2d") + res = ParseQnnConvComp(comp_fn, ext_attrs, args); + else if (name == "dnnl.qnn.dense") + res = ParseQnnDenseComp(comp_fn, ext_attrs, args); + return res; +} + +#endif // TVM_RELAY_BACKEND_CONTRIB_DNNL_COMP_OP_MATCHER_H_ diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index 5045f3323af7..a4239186b4b3 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -134,9 +134,56 @@ class DNNLJSONRuntime : public JSONRuntimeBase { {"tanh", dnnl::algorithm::eltwise_tanh}, {"sigmoid", dnnl::algorithm::eltwise_logistic}, {"clip", dnnl::algorithm::eltwise_clip}, + {"gelu_erf", dnnl::algorithm::eltwise_gelu_erf}, }; - bool ParsingOpName(const std::string op_name, dnnl::primitive_attr attr) { + dnnl::primitive_attr ParseAttrs(const size_t& nid, TensorRequisite* bias_tr) { + dnnl::primitive_attr attr; + + // Post op attributes based on named inputs. + auto dst_zp_tr = GetInputByName(nid, "dst_zp_idx"); + auto o_scl_tr = GetInputByName(nid, "o_scl_idx"); + auto sum_scl_tr = GetInputByName(nid, "sum_scl_idx"); + + if (o_scl_tr) { + ICHECK(o_scl_tr.IsConstant()); + auto data = o_scl_tr.GetConstDataLikeVec(); + attr.set_output_scales(data.size() == 1 ? 0 : (1 << 1), data); + } + + auto activation = GetNodeAttr>(nodes_[nid], "activation", {"none"}); + if (activation[0] != "none") { + auto a_type = elt_name2algo.at(activation[0]); + auto a_scale = GetInput(nid, std::stoi(activation[1])).GetConstScalarData(); + auto a_alfa = GetInput(nid, std::stoi(activation[2])).GetConstScalarData(); + auto a_beta = GetInput(nid, std::stoi(activation[3])).GetConstScalarData(); + + auto ops = attr.get_post_ops(); + ops.append_eltwise(a_scale, a_type, a_alfa, a_beta); + attr.set_post_ops(ops); + } + + if (sum_scl_tr) { + auto scl = sum_scl_tr.GetConstScalarData(); + auto ops = attr.get_post_ops(); + ops.append_sum(scl); + attr.set_post_ops(ops); + } + + if (dst_zp_tr) { + auto zp = dst_zp_tr.GetConstScalarData(); + // Use linear post op instead of set_zero_points(). Because of limitation of int32 type, + // but we have to use float. + auto ops = attr.get_post_ops(); + ops.append_eltwise(1.0, dnnl::algorithm::eltwise_linear, 1.0, zp); + attr.set_post_ops(ops); + } + *bias_tr = GetInputByName(nid, "bias_idx"); + + if (o_scl_tr || activation[0] != "none" || sum_scl_tr || dst_zp_tr) return attr; + + // parsing of name to extract attributes + auto op_name = nodes_[nid].GetOpName(); // Define RegExp. std::regex bias_add_pat(".*_bias.*"); std::regex relu_pat(".*_relu.*"); @@ -163,7 +210,9 @@ class DNNLJSONRuntime : public JSONRuntimeBase { } // Parsing bias_add. - return std::regex_match(op_name, bias_add_pat) ? true : false; + *bias_tr = std::regex_match(op_name, bias_add_pat) ? 
GetInput(nid, 2) : TensorRequisite{}; + + return attr; } // Build up the engine based on the input graph. @@ -219,16 +268,16 @@ class DNNLJSONRuntime : public JSONRuntimeBase { void Convolution(const size_t& nid) { auto node = nodes_[nid]; - auto op_name = node.GetOpName(); - dnnl::primitive_attr attr; - attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); - bool has_bias = ParsingOpName(op_name, attr); // Setup attributes. auto src_tr = GetInput(nid, 0); auto wgh_tr = GetInput(nid, 1); auto dst_tr = GetOutput(nid, 0); - auto bias_tr = has_bias ? GetInput(nid, 2) : GetInput(nid, -1); + auto bias_tr = TensorRequisite{}; + + auto attr = ParseAttrs(nid, &bias_tr); + attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + auto strides = GetNodeAttr>(node, "strides"); auto dilates = GetNodeAttr>(node, "dilation"); auto padding = GetNodeAttr>(node, "padding"); @@ -292,25 +341,29 @@ class DNNLJSONRuntime : public JSONRuntimeBase { auto scratchpad_tr = TensorRequisite::AsIs(conv_prim_desc.scratchpad_desc()); - Submit(dnnl::convolution_forward(conv_prim_desc), {{DNNL_ARG_SRC, src_tr}, - {DNNL_ARG_WEIGHTS, wgh_tr}, - {DNNL_ARG_BIAS, bias_tr}, - {DNNL_ARG_SCRATCHPAD, scratchpad_tr}, - {DNNL_ARG_DST, dst_tr}}); + // TODO(@apeskov): Simulation of inplace primitive. just as PoC. + auto sum_in_tr = GetInputByName(nid, "sum_idx").TreatAs(dst_layout); + + Submit(dnnl::convolution_forward(conv_prim_desc), + {{DNNL_ARG_SRC, src_tr}, + {DNNL_ARG_WEIGHTS, wgh_tr}, + {DNNL_ARG_BIAS, bias_tr}, + {DNNL_ARG_SCRATCHPAD, scratchpad_tr}, + {DNNL_ARG_DST, dst_tr}}, + {sum_in_tr, DNNL_ARG_DST}); } void Deconvolution(const size_t& nid) { auto node = nodes_[nid]; - auto op_name = node.GetOpName(); - dnnl::primitive_attr attr; - attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); - bool has_bias = ParsingOpName(op_name, attr); // Setup attributes. auto src_tr = GetInput(nid, 0); auto wgh_tr = GetInput(nid, 1); auto dst_tr = GetOutput(nid, 0); - auto bias_tr = has_bias ? GetInput(nid, 2) : GetInput(nid, -1); + auto bias_tr = TensorRequisite{}; + + auto attr = ParseAttrs(nid, &bias_tr); + attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); auto strides = GetNodeAttr>(node, "strides"); auto dilates = GetNodeAttr>(node, "dilation"); @@ -374,16 +427,15 @@ class DNNLJSONRuntime : public JSONRuntimeBase { void Dense(const size_t& nid) { auto node = nodes_[nid]; - auto op_name = node.GetOpName(); - dnnl::primitive_attr attr; - attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); - bool has_bias = ParsingOpName(op_name, attr); // Setup attributes. auto src_tr = GetInput(nid, 0); auto wgh_tr = GetInput(nid, 1); auto dst_tr = GetOutput(nid, 0); - auto bias_tr = has_bias ? GetInput(nid, 2) : GetInput(nid, -1); + auto bias_tr = TensorRequisite{}; + + auto attr = ParseAttrs(nid, &bias_tr); + attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); // Assumption that bias is correct and can be squeezed to 1D bias_tr = bias_tr.Reshape({dst_tr.dims()[1]}); @@ -403,11 +455,16 @@ class DNNLJSONRuntime : public JSONRuntimeBase { auto scratchpad_tr = TensorRequisite::AsIs(dense_prim_desc.scratchpad_desc()); - Submit(dnnl::inner_product_forward(dense_prim_desc), {{DNNL_ARG_SRC, src_tr}, - {DNNL_ARG_WEIGHTS, wgh_tr}, - {DNNL_ARG_BIAS, bias_tr}, - {DNNL_ARG_SCRATCHPAD, scratchpad_tr}, - {DNNL_ARG_DST, dst_tr}}); + // TODO(@apeskov): Simulation of inplace primitive. just as PoC. 
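+    // How the simulation works: oneDNN's sum post-op accumulates into the
+    // destination (dst = primitive result + sum_scl * dst), so Submit() below
+    // registers a reorder that copies the extra "sum" input into the
+    // destination tensor right before the primitive executes.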
+ auto sum_in_tr = GetInputByName(nid, "sum_idx"); + + Submit(dnnl::inner_product_forward(dense_prim_desc), + {{DNNL_ARG_SRC, src_tr}, + {DNNL_ARG_WEIGHTS, wgh_tr}, + {DNNL_ARG_BIAS, bias_tr}, + {DNNL_ARG_SCRATCHPAD, scratchpad_tr}, + {DNNL_ARG_DST, dst_tr}}, + {sum_in_tr, DNNL_ARG_DST}); } void BatchNorm(const size_t& nid) { @@ -675,6 +732,11 @@ class DNNLJSONRuntime : public JSONRuntimeBase { return res; } + TensorRequisite GetInputByName(const size_t& nid, const std::string& name) { + auto idx = GetNodeAttr(nodes_[nid], name, {"-1"}); + return GetInput(nid, idx); + } + TensorRequisite GetOutput(const size_t& nid, const int idx) { if (idx == -1) return {}; // -1 reserved value for empty input. @@ -692,8 +754,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase { } /*! \brief Helper function to register primitive into execution queue */ - void Submit(const dnnl::primitive& prim, - const std::unordered_map& tr_args) { + void Submit(const dnnl::primitive& prim, const std::unordered_map& tr_args, + const std::pair& inplace_conf = {}) { // Register all provided TR arguments std::unordered_map prim_arg_id; TensorRegistry::ActionQue post_prim_actions; @@ -706,6 +768,18 @@ class DNNLJSONRuntime : public JSONRuntimeBase { prim_arg_id[key] = arg_id; } + // Simulate inplace primitive + if (auto tr = inplace_conf.first) { + auto arg_id = tensor_registry_.Register(tr, &net_); + auto dst_tr = tr_args.at(inplace_conf.second); + auto dst_arg_id = prim_arg_id.at(inplace_conf.second); + + // Register copy action direct before main primitive + dnnl::reorder::primitive_desc io_copy_pd(engine_, tr.desc(), engine_, dst_tr.desc()); + net_.push_back( + {dnnl::reorder(io_copy_pd), {{DNNL_ARG_SRC, arg_id}, {DNNL_ARG_DST, dst_arg_id}}}); + } + // Register main primitive net_.push_back({prim, prim_arg_id}); diff --git a/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h b/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h index d02ceff5de82..bad4bc10edec 100644 --- a/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h +++ b/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h @@ -275,6 +275,7 @@ class TensorRequisite { * innermost. 
*/ TensorRequisite TreatAs(const std::string& layout, std::string desired_logic_layout = "") const { + if (!defined()) return *this; if (desired_logic_layout.empty()) desired_logic_layout = DefaultLogicLayoutFor(layout); const auto origin_dims = dims(); diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py index c884665421cb..2138eda08697 100755 --- a/tests/python/contrib/test_dnnl.py +++ b/tests/python/contrib/test_dnnl.py @@ -20,6 +20,7 @@ import sys import subprocess import math +import collections import tvm from tvm import relay @@ -51,7 +52,7 @@ def bf16_supported(): cpu_info = subprocess.check_output("sysctl -a", shell=True).strip().decode() for line in cpu_info.split("\n"): if line.startswith("hw.optional.avx512f"): - _bf16_supported = bool(line.split(":", 1)[1]) + _bf16_supported = bool(int(line.split(":", 1)[1])) elif sys.platform.startswith("linux"): _bf16_supported = "avx512" in open("/proc/cpuinfo", "r").read() return _bf16_supported @@ -114,6 +115,7 @@ def partition_for_dnnl(mod, params=None, alter_layout=True, prune_subgraphs=True mod = dnnl.rewrite_layer_norm(mod) mod = dnnl.rewrite_dense_bias_gelu_reshape_last(mod) + mod = dnnl.legalize_qnn_for_dnnl(mod) byoc_seq = tvm.transform.Sequential( [ @@ -1126,5 +1128,540 @@ def get_graph(act=None): ) +def permute_shape(shape, l_from="", l_to=""): + res_shape = [] + for label in l_to: + pos = l_from.find(label) + res_shape.append(shape[pos]) + + return res_shape + + +def expand_dim(shape, rank=0): + assert len(shape) == 1 + return shape + [1] * (rank - 1) + + +def filler_uni(low=0, high=1): + def filler_func(shape): + return np.random.uniform(low, high, shape) + + return filler_func + + +class QnnBuilder: + def __init__(self, qnn_profile=None): + self._args = {} + self._args_op = [] + self._qp = qnn_profile + + def arg(self, shape=[], dtype="float32", filler=filler_uni(), is_const=True): + if isinstance(filler, (int, float)): + value = np.full(shape, filler).astype(dtype) + else: + value = filler(shape).astype(dtype) + + if is_const: + res = relay.const(value, dtype=dtype) + else: + name = f"in_{len(self._args)}" + res = relay.var(name, shape=shape, dtype=dtype) + self._args[name] = value + self._args_op.append(res) + + return res + + def make_zp(self, mean_val, num_ch=1, dispersion=0.2): + if num_ch == 1: + return self.arg(shape=[], dtype="int32", filler=mean_val) + else: + low = int(mean_val * (1 - dispersion)) + high = int(mean_val * (1 + dispersion)) + return self.arg(shape=[num_ch], dtype="int32", filler=filler_uni(low, high)) + + def make_scl(self, mean_val, num_ch=1, dispersion=0.2): + if num_ch == 1: + return self.arg(shape=[], dtype="float32", filler=mean_val) + else: + low = mean_val * (1 - dispersion) + high = mean_val * (1 + dispersion) + return self.arg(shape=[num_ch], dtype="float32", filler=filler_uni(low, high)) + + def make_zp_and_scl(self, name, num_ch=1, dispersion=0.2): + is_per_channel = getattr(self._qp, f"{name}_pc") + zp_val = getattr(self._qp, f"{name}_zp") + scl_val = getattr(self._qp, f"{name}_scl") + + zp = self.make_zp(zp_val, num_ch if is_per_channel else 1, dispersion) + scl = self.make_scl(scl_val, num_ch if is_per_channel else 1, dispersion) + return zp, scl + + def finalize(self, op): + func = relay.Function(self._args_op, op) + mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) + return mod, self._args + + +def check_fully_annotated(mod, desired_compiler): + matched_ops = [] + other_ops = [] + + def _visit(node): + if isinstance(node, 
tvm.relay.Call): + op = node.op + if isinstance(op, relay.GlobalVar): + func = mod[op] + if "Compiler" in func.attrs and func.attrs["Compiler"] == desired_compiler: + matched_ops.append(op) + return + else: + other_ops.append(op) + + tvm.relay.analysis.post_order_visit(mod["main"].body, _visit) + + assert len(other_ops) == 0 and len(matched_ops) != 0, "Model is not fully DNNL compiled" + + +def check_result( + mod, + ref_mod, + map_inputs, + tol=1e-5, + target="llvm", + device=tvm.cpu(), + params=None, + ref_result=None, + atol=None, + desired_compiler="dnnl", +): + if atol is None: + atol = tol + + if desired_compiler is not None: + check_fully_annotated(mod, desired_compiler) + + if ref_result is None: + # Run the reference result + relay.backend.te_compiler.get().clear() + with tvm.transform.PassContext(opt_level=3): + ref_lib = relay.build(ref_mod, target=target, params=params) + ref_rt_mod = tvm.contrib.graph_executor.GraphModule(ref_lib["default"](device)) + + for name, data in map_inputs.items(): + ref_rt_mod.set_input(name, data) + ref_rt_mod.run() + out = ref_rt_mod.get_output(0) + ref_result = out.numpy() + + def check_vm_result(): + relay.backend.te_compiler.get().clear() + with tvm.transform.PassContext(opt_level=3): + exe = relay.vm.compile(mod, target=target, params=params) + code, lib = exe.save() + exe = tvm.runtime.vm.Executable.load_exec(code, lib) + vm = tvm.runtime.vm.VirtualMachine(exe, device) + output = vm.run(**map_inputs) + tvm.testing.assert_allclose(output.numpy(), ref_result, rtol=tol, atol=atol) + + def check_graph_executor_result(): + relay.backend.te_compiler.get().clear() + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target, params=params) + rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](device)) + + rt_mod.run(**map_inputs) + output = rt_mod.get_output(0) + tvm.testing.assert_allclose(output.numpy(), ref_result, rtol=tol, atol=atol) + + check_vm_result() + check_graph_executor_result() + + +ConvProfile = collections.namedtuple( + "ConvProfile", + [ + "SHAPE", + "KER", + "STR", + "PAD", + "DEL", + "OC", + "GR", + "D_LAYOUT", + "K_LAYOUT", + ], +) +base_conv = ConvProfile( + SHAPE=[1, 8, 5, 5], + KER=[3, 3], + STR=[1, 1], + PAD=[1, 1], + DEL=[1, 1], + OC=16, + GR=1, + D_LAYOUT="NCHW", + K_LAYOUT="OIHW", +) +base_conv_nhwc = base_conv._replace(D_LAYOUT="NHWC", K_LAYOUT="HWIO") +base_conv_dilated = base_conv._replace(PAD=[2, 2], DEL=[2, 2]) +base_conv_no_pad = base_conv._replace(PAD=[0, 0]) +base_conv_no_pad_nhwc = base_conv_no_pad._replace(D_LAYOUT="NHWC", K_LAYOUT="HWIO") +base_conv_group_no_pad = base_conv_no_pad._replace(GR=2) +base_conv_dw_no_pad = base_conv_no_pad._replace(SHAPE=[1, 16, 5, 5], GR=16) + + +DenseProfile = collections.namedtuple("DenseProfile", ["N", "IC", "OC"]) +base_dense_profile = DenseProfile(N=2, IC=10, OC=16) + +ArgConstConfig = collections.namedtuple("ArgConstConfig", ["Data", "Weights", "Bias", "Sum"]) +acp_regular = ArgConstConfig(Data=False, Weights=True, Bias=True, Sum=None) +acp_no_bias = ArgConstConfig(Data=False, Weights=True, Bias=None, Sum=None) +acp_with_sum = ArgConstConfig(Data=False, Weights=True, Bias=True, Sum=False) +acp_no_bias_with_sum = ArgConstConfig(Data=False, Weights=True, Bias=None, Sum=False) + +QuantizationConfig = collections.namedtuple( + "QuantizationConfig", + [ + "d_zp", + "d_scl", + "d_pc", + "k_zp", + "k_scl", + "k_pc", + "rq_zp", + "rq_scl", + "rq_pc", + "sum_zp", + "sum_scl", + "sum_pc", + "o_zp", + "o_scl", + "o_pc", + ], +) + +qp_regular = 
QuantizationConfig( + d_zp=0, + d_scl=0.2, + d_pc=False, + k_zp=0, + k_scl=0.1, + k_pc=False, + rq_zp=30, + rq_scl=0.2, + rq_pc=False, + sum_zp=15, + sum_scl=0.3, + sum_pc=False, + o_zp=5, + o_scl=0.2, + o_pc=False, +) +qp_asymmetric_data = qp_regular._replace( + d_zp=3, rq_zp=10, rq_scl=0.1, sum_zp=15, sum_scl=0.3, o_zp=4 +) + +qnn_conv_profiles = tvm.testing.parameter( + by_dict={ + # Pattern qnn.conv2d + qnn.requantize + "Base": (base_conv, acp_regular, qp_regular), + "NHWC": (base_conv_nhwc, acp_regular, qp_regular), + # Asymmetric input. NOTE: No pad! Input ZP is not compatible with padding + "Group": (base_conv_group_no_pad, acp_regular, qp_asymmetric_data), + "DW": (base_conv_dw_no_pad, acp_regular, qp_asymmetric_data), + "NoBias": (base_conv, acp_no_bias, qp_regular), + "AsymmetricInput": (base_conv_no_pad, acp_regular, qp_asymmetric_data), + "AsymmetricInput_NHWC": (base_conv_no_pad_nhwc, acp_regular, qp_asymmetric_data), + # Pattern Conv2d + Requantize + Sum + "WithSum": (base_conv_no_pad, acp_with_sum, qp_asymmetric_data), + "WithSum_NHWC": (base_conv_no_pad_nhwc, acp_with_sum, qp_asymmetric_data), + "WithSum_NoBias": (base_conv_no_pad, acp_no_bias_with_sum, qp_asymmetric_data), + } +) + + +@has_dnnl_codegen +def test_qnn_conv2d(qnn_conv_profiles): + def generate_model(p, c, q): + np.random.seed(0) + + N, IC, IH, IW = p.SHAPE + d_shape = p.SHAPE + w_shape = [p.OC, IC, *p.KER] + b_shape = [p.OC] + s_shape = [ + p.SHAPE[0], + p.OC, + (IH + 2 * p.PAD[0] - (p.KER[0] - 1) * p.DEL[0] - 1) // p.STR[0] + 1, + (IW + 2 * p.PAD[1] - (p.KER[1] - 1) * p.DEL[1] - 1) // p.STR[1] + 1, + ] + + if p.GR != 1: + w_shape[1] //= p.GR + + d_shape = permute_shape(d_shape, l_from="NCHW", l_to=p.D_LAYOUT) + s_shape = permute_shape(s_shape, l_from="NCHW", l_to=p.D_LAYOUT) + w_shape = permute_shape(w_shape, l_from="OIHW", l_to=p.K_LAYOUT) + + c_dim = p.D_LAYOUT.find("C") + b_shape = expand_dim(b_shape, rank=len(p.D_LAYOUT) - c_dim) + + bld = QnnBuilder(qnn_profile=q) + + # Start build a test graph + data = bld.arg(shape=d_shape, dtype="uint8", is_const=c.Data, filler=filler_uni(0, 20)) + d_zp, d_scl = bld.make_zp_and_scl("d", IC) + + # Convolution + wgh = bld.arg(shape=w_shape, dtype="int8", is_const=c.Weights, filler=filler_uni(-20, 20)) + w_zp, w_scl = bld.make_zp_and_scl("k") + + op = tvm.relay.qnn.op.conv2d( + data, + wgh, + d_zp, + w_zp, + d_scl, + w_scl, + kernel_size=p.KER, + padding=p.PAD, + strides=p.STR, + dilation=p.DEL, + groups=p.GR, + channels=p.OC, + out_dtype="int32", + data_layout=p.D_LAYOUT, + kernel_layout=p.K_LAYOUT, + ) + # Optional bias + if c.Bias is not None: + bias = bld.arg( + shape=b_shape, dtype="int32", is_const=c.Bias, filler=filler_uni(-50, 50) + ) + op = tvm.relay.add(op, bias) + + # Re-quantization + rq_in_zp = bld.make_zp(0) + rq_in_scl = bld.make_scl(q.d_scl * q.k_scl) # in real cases that should be a vector + rq_out_zp, rq_out_scl = bld.make_zp_and_scl("rq") + + op = tvm.relay.qnn.op.requantize( + op, rq_in_scl, rq_in_zp, rq_out_scl, rq_out_zp, out_dtype="int32" + ) + op = tvm.relay.clip( + op, a_min=0.0, a_max=255.0 + ) # pytorch frontend specific, I guess it's redundant + op = tvm.relay.cast(op, dtype="uint8") + + # Optional sum (ResNet like) + if c.Sum is not None: + sum_in = bld.arg(dtype="uint8", shape=s_shape, filler=filler_uni(0, 10), is_const=c.Sum) + + lhs_zp, lhs_scl = bld.make_zp_and_scl("rq") + rhs_zp, rhs_scl = bld.make_zp_and_scl("sum") + out_zp, out_scl = bld.make_zp_and_scl("o") + + op = tvm.relay.qnn.op.add(op, sum_in, lhs_scl, lhs_zp, rhs_scl, rhs_zp, 
out_scl, out_zp) + op = tvm.relay.clip(op, a_min=0.0, a_max=255.0) + + return bld.finalize(op) + + conv_p, arg_p, quant_p = qnn_conv_profiles + ref_mod, args = generate_model(conv_p, arg_p, quant_p) + mod = partition_for_dnnl(ref_mod) + + # atol=1 means int values should match with +-1 quantum value tolerance + check_result(mod, ref_mod, args, tol=1e-10, atol=1, desired_compiler="dnnl") + + +conv_profiles = tvm.testing.parameter( + by_dict={ + "Base": (base_conv, acp_regular), + "NHWC": (base_conv_nhwc, acp_regular), + "Group": (base_conv_group_no_pad, acp_regular), + "DW": (base_conv_dw_no_pad, acp_regular), + "Dilated": (base_conv_dilated, acp_regular), + } +) + + +@has_dnnl_codegen +def test_conv2d_plus(conv_profiles): + def generate_model(p, c): + np.random.seed(0) + + N, IC, IH, IW = p.SHAPE + d_shape = p.SHAPE + w_shape = [p.OC, IC, *p.KER] + b_shape = [p.OC] + s_shape = [ + p.SHAPE[0], + p.OC, + (IH + 2 * p.PAD[0] - (p.KER[0] - 1) * p.DEL[0] - 1) // p.STR[0] + 1, + (IW + 2 * p.PAD[1] - (p.KER[1] - 1) * p.DEL[1] - 1) // p.STR[1] + 1, + ] + + if p.GR != 1: + w_shape[1] //= p.GR + + d_shape = permute_shape(d_shape, l_from="NCHW", l_to=p.D_LAYOUT) + s_shape = permute_shape(s_shape, l_from="NCHW", l_to=p.D_LAYOUT) + w_shape = permute_shape(w_shape, l_from="OIHW", l_to=p.K_LAYOUT) + + c_dim = p.D_LAYOUT.find("C") + # b_shape = expand_dim(b_shape, rank=len(p.D_LAYOUT) - c_dim) + + bld = QnnBuilder() + + op = bld.arg(shape=d_shape, dtype="float32", is_const=c.Data) + wgh = bld.arg(shape=w_shape, dtype="float32", is_const=c.Weights) + op = tvm.relay.nn.conv2d( + op, + wgh, + kernel_size=p.KER, + padding=p.PAD, + strides=p.STR, + dilation=p.DEL, + groups=p.GR, + channels=p.OC, + out_dtype="float32", + data_layout=p.D_LAYOUT, + kernel_layout=p.K_LAYOUT, + ) + + if c.Bias is not None: + bias = bld.arg(shape=b_shape, dtype="float32", is_const=c.Bias) + op = tvm.relay.nn.bias_add(op, bias, axis=c_dim) + + if c.Sum is not None: + sum_in = bld.arg(shape=s_shape, dtype="float32", is_const=c.Sum) + op = tvm.relay.op.add(op, sum_in) + + return bld.finalize(op) + + conv_p, arg_p = conv_profiles + ref_mod, args = generate_model(conv_p, arg_p) + mod = partition_for_dnnl(ref_mod, alter_layout=False) + check_result(mod, ref_mod, args, tol=1e-5, desired_compiler="dnnl") + + +qnn_dense_profiles = tvm.testing.parameter( + by_dict={ + # Pattern Dense + Requantize + "Base": (base_dense_profile, acp_regular, qp_regular), + "AsymmetricInput": (base_dense_profile, acp_regular, qp_asymmetric_data), + # Pattern Dense + Requantize + Sum + "AsymmetricInput_Sum": (base_dense_profile, acp_with_sum, qp_asymmetric_data), + } +) + + +@has_dnnl_codegen +def test_qnn_dense(qnn_dense_profiles): + def generate_model(p, c, q): + np.random.seed(0) + + d_shape = [p.N, p.IC] + w_shape = [p.OC, p.IC] + b_shape = [p.OC] + s_shape = [p.N, p.OC] + + bld = QnnBuilder(qnn_profile=q) + + # Start build a test graph + data = bld.arg(shape=d_shape, dtype="uint8", is_const=c.Data, filler=filler_uni(0, 20)) + d_zp, d_scl = bld.make_zp_and_scl("d", p.IC) + + # Convolution + wgh = bld.arg(shape=w_shape, dtype="int8", is_const=c.Weights, filler=filler_uni(-20, 20)) + w_zp, w_scl = bld.make_zp_and_scl("k") + + op = tvm.relay.qnn.op.dense( + data, wgh, d_zp, w_zp, d_scl, w_scl, units=p.OC, out_dtype="int32" + ) + # Optional bias + if c.Bias is not None: + bias = bld.arg( + shape=b_shape, dtype="int32", is_const=c.Bias, filler=filler_uni(-50, 50) + ) + op = tvm.relay.add(op, bias) + + # Re-quantization + rq_in_zp = bld.make_zp(0) + rq_in_scl = 
bld.make_scl(q.d_scl * q.k_scl) # in real cases that should be a vector + rq_out_zp, rq_out_scl = bld.make_zp_and_scl("rq") + + op = tvm.relay.qnn.op.requantize( + op, rq_in_scl, rq_in_zp, rq_out_scl, rq_out_zp, out_dtype="int32" + ) + op = tvm.relay.clip( + op, a_min=0.0, a_max=255.0 + ) # pytorch frontend specific, I guess it's redundant + op = tvm.relay.cast(op, dtype="uint8") + + # Optional sum (ResNet like) + if c.Sum is not None: + sum_in = bld.arg(dtype="uint8", shape=s_shape, filler=filler_uni(0, 10), is_const=c.Sum) + + lhs_zp, lhs_scl = bld.make_zp_and_scl("rq") + rhs_zp, rhs_scl = bld.make_zp_and_scl("sum") + out_zp, out_scl = bld.make_zp_and_scl("o") + + op = tvm.relay.qnn.op.add(op, sum_in, lhs_scl, lhs_zp, rhs_scl, rhs_zp, out_scl, out_zp) + op = tvm.relay.clip(op, a_min=0.0, a_max=255.0) + + return bld.finalize(op) + + conv_p, arg_p, quant_p = qnn_dense_profiles + ref_mod, args = generate_model(conv_p, arg_p, quant_p) + mod = partition_for_dnnl(ref_mod) + + # atol=1 means int values should match with +-1 quantum value tolerance + check_result(mod, ref_mod, args, tol=1e-10, atol=1, desired_compiler="dnnl") + + +dense_profiles = tvm.testing.parameter( + by_dict={ + "Base": (base_dense_profile, acp_regular), + "WithSum": (base_dense_profile, acp_with_sum), + } +) + + +@has_dnnl_codegen +def test_dense_plus(dense_profiles): + def generate_model(p, c): + np.random.seed(0) + + d_shape = [p.N, p.IC] + w_shape = [p.OC, p.IC] + b_shape = [p.OC] + s_shape = [p.N, p.OC] + + c_dim = 1 + + bld = QnnBuilder() + + op = bld.arg(shape=d_shape, dtype="float32", is_const=c.Data) + wgh = bld.arg(shape=w_shape, dtype="float32", is_const=c.Weights) + op = tvm.relay.nn.dense(op, wgh, out_dtype="float32") + + if c.Bias is not None: + bias = bld.arg(shape=b_shape, dtype="float32", is_const=c.Bias) + op = tvm.relay.nn.bias_add(op, bias, axis=c_dim) + + if c.Sum is not None: + sum_in = bld.arg(shape=s_shape, dtype="float32", is_const=c.Sum) + op = tvm.relay.op.add(op, sum_in) + + return bld.finalize(op) + + dense_p, arg_p = dense_profiles + ref_mod, args = generate_model(dense_p, arg_p) + mod = partition_for_dnnl(ref_mod) + check_result(mod, ref_mod, args, tol=1e-5, desired_compiler="dnnl") + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index dedeae56e9da..58b41189a0f0 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -926,11 +926,11 @@ def test_dnnl_fuse(): conv2d_relu_pat, conv2d_sigmoid_pat, ) = ( - dnnl_patterns[1], - dnnl_patterns[13], - dnnl_patterns[20], - dnnl_patterns[26], - dnnl_patterns[38], + dnnl_patterns[3], + dnnl_patterns[15], + dnnl_patterns[22], + dnnl_patterns[28], + dnnl_patterns[40], ) def get_blocks(