diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index d848d9030c48..774ab07d8dc4 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -158,7 +158,6 @@ def optimize(self, func, target=None, params=None): return mod, params - def _set_params(self, params): self._set_params_func(_convert_param_map(params)) diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index 759a4421bc1d..807714a33bd5 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -31,6 +31,7 @@ #include #include +#include #include "../codegen_c/codegen_c.h" @@ -50,82 +51,109 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { out_.push_back({node->name_hint(), 0}); } - void VisitExpr_(const TupleGetItemNode* op) final { - // Do nothing - } - void VisitExpr_(const CallNode* call) final { - std::ostringstream decl_stream; - std::ostringstream buf_stream; - // Args: ID - std::vector args; + struct Output { + std::string decl, buf; + int out_size = 1; + std::string out; + }; + + auto generate_body = [=](const CallNode* root_call, const std::string& func_name, + const std::vector& args, + const std::vector& fused_func_args) { + // Make function call with input buffers when visiting arguments + bool first = true; + std::ostringstream arg_stream; + arg_stream << "("; + for (size_t i = 0; i < root_call->args.size(); ++i) { + VisitExpr(root_call->args[i]); + for (auto out : out_) { + if (!first) { + arg_stream << ", "; + } + first = false; + arg_stream << out.first; + } + } + + for (auto arg_name : fused_func_args) { + arg_stream << ", " << arg_name; + } + + // Analyze the output buffer + auto type_node = root_call->checked_type().as(); + CHECK(type_node != nullptr && runtime::TypeMatch(type_node->dtype, kDLFloat, 32)) + << "Only support single output tensor with float type"; + + auto out_shape = GetShape(root_call->checked_type()); + + 
Output ret; + ret.out = "buf_" + std::to_string(buf_idx_++); + ret.out_size = std::accumulate(out_shape.begin(), out_shape.end(), 1, std::multiplies()); + + this->PrintIndents(); + + std::ostringstream buf_stream; + buf_stream << "float* " << ret.out << " = (float*)std::malloc(4 * " << ret.out_size << ");"; + ret.buf = buf_stream.str(); - // Get the arguments for various DNNL kernels. - if (IsOp(call, "nn.conv2d")) { - decl_stream << "dnnl_conv2d"; - args = Conv2d(call); + arg_stream << ", " << ret.out; + // Attach attribute arguments + for (size_t i = 0; i < args.size(); ++i) { + arg_stream << ", " << args[i]; + } + arg_stream << ");"; + ret.decl = func_name + arg_stream.str(); + + return ret; + }; + + Output ret; + if (auto conv_call = DetectFusedConv2DBiasReLU(call)) { + ret = generate_body(conv_call, "dnnl_fused_conv2d_bias_relu", + FusedConv2dBiasReLU(conv_call), ext_fused_func_args_); + } else if (IsOp(call, "nn.conv2d")) { + ret = generate_body(call, "dnnl_conv2d", Conv2d(call), {}); } else if (IsOp(call, "nn.dense")) { - decl_stream << "dnnl_dense"; - args = Dense(call); + ret = generate_body(call, "dnnl_dense", Dense(call), {}); } else if (IsOp(call, "nn.relu")) { - decl_stream << "dnnl_relu"; - args = Relu(call); + ret = generate_body(call, "dnnl_relu", Relu(call), {}); } else if (IsOp(call, "nn.batch_norm")) { - decl_stream << "dnnl_bn"; - args = BatchNorm(call); + ret = generate_body(call, "dnnl_bn", BatchNorm(call), {}); } else if (IsOp(call, "add")) { - decl_stream << "dnnl_add"; - args = Add(call); + ret = generate_body(call, "dnnl_add", Add(call), {}); } else { LOG(FATAL) << "Unsupported op: " << AsText(call->op, false); } - // Make function call with input buffers when visiting arguments - bool first = true; - decl_stream << "("; - for (size_t i = 0; i < call->args.size(); ++i) { - VisitExpr(call->args[i]); - for (auto out : out_) { - if (!first) { - decl_stream << ", "; - } - first = false; - decl_stream << out.first; - } - } - - // Analyze the 
output buffer - auto type_node = call->checked_type().as(); - CHECK(type_node != nullptr && runtime::TypeMatch(type_node->dtype, kDLFloat, 32)) - << "Only support single output tensor with float type"; - std::string out = "buf_" + std::to_string(buf_idx_++); - auto out_shape = GetShape(call->checked_type()); - int out_size = 1; - for (size_t i = 0; i < out_shape.size(); ++i) { - out_size *= out_shape[i]; - } - this->PrintIndents(); - buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");"; - buf_decl_.push_back(buf_stream.str()); - decl_stream << ", " << out; - - // Attach attribute arguments - for (size_t i = 0; i < args.size(); ++i) { - decl_stream << ", " << args[i]; - } - decl_stream << ");"; - ext_func_body.push_back(decl_stream.str()); + buf_decl_.push_back(ret.buf); + ext_func_body.push_back(ret.decl); // Update output buffer out_.clear(); - out_.push_back({out, out_size}); + out_.push_back({ret.out, ret.out_size}); } std::string JIT(void) { + ext_func_args_.insert(ext_func_args_.end(), ext_fused_func_args_.begin(), + ext_fused_func_args_.end()); return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out_); } private: + const CallNode* DetectFusedConv2DBiasReLU(const CallNode* call) { + if (!IsOp(call, "nn.relu")) return nullptr; + auto relu_arg = call->args[0]; + const CallNode* add_call = relu_arg.as(); + if (!add_call || !IsOp(add_call, "add")) return nullptr; + auto add_arg = add_call->args[0]; + const CallNode* conv_call = add_arg.as(); + if (!conv_call || !IsOp(conv_call, "nn.conv2d")) return nullptr; + auto bias_name = "dnnl_fused_input" + std::to_string(ext_fused_func_args_.size()); + ext_fused_func_args_.push_back(bias_name); + return conv_call; + } + std::vector Conv2d(const CallNode* call) { std::vector args; const auto* conv2d_attr = call->attrs.as(); @@ -152,6 +180,10 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { return args; } + std::vector FusedConv2dBiasReLU(const CallNode* call) 
{ + return Conv2d(call); + } + std::vector Dense(const CallNode* call) { std::vector args; auto ishape = GetShape(call->args[0]->checked_type()); @@ -214,6 +246,7 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { int buf_idx_{0}; /*! \brief The arguments used by a wrapped function that calls DNNL kernels. */ std::vector ext_func_args_; + std::vector ext_fused_func_args_; /*! \brief statement of the function that will be compiled using DNNL kernels. */ std::vector ext_func_body; /*! \brief The declaration of intermeidate buffers. */ diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc index cc430b2c7c76..5622d8feeed1 100644 --- a/src/runtime/contrib/dnnl/dnnl.cc +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -52,10 +52,10 @@ inline void read_from_dnnl_memory(void* handle, const memory& mem) { std::copy(src, src + bytes, reinterpret_cast(handle)); } -extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, - int p_C_, int p_H_, int p_W_, int p_O_, int p_G_, - int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_, - int p_Sh_, int p_Sw_) { +void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, + int p_N_, int p_C_, int p_H_, int p_W_, + int p_O_, int p_G_, int p_Ph_, int p_Pw_, int p_Kh_, + int p_Kw_, int p_Sh_, int p_Sw_, primitive_attr attr) { using tag = memory::format_tag; using dt = memory::data_type; engine eng(engine::kind::cpu, 0); @@ -65,21 +65,16 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, memory::dims conv2d_weights_tz = {p_O_, p_C_, p_Kh_, p_Kw_}; if (p_G_ > 1) conv2d_weights_tz = {p_G_, 1, p_C_ / p_G_, p_Kh_, p_Kw_}; memory::dims conv2d_bias_tz = {p_O_}; - memory::dims conv2d_dst_tz = {p_N_, p_O_, - (p_H_ - p_Kh_ + 2 * p_Ph_ + p_Sh_) / p_Sh_, + memory::dims conv2d_dst_tz = {p_N_, p_O_, (p_H_ - p_Kh_ + 2 * p_Ph_ + p_Sh_) / p_Sh_, (p_W_ - p_Kw_ + 2 * p_Pw_ + p_Sw_) / p_Sw_}; memory::dims conv2d_strides = {p_Sh_, p_Sw_}; memory::dims conv2d_padding 
= {p_Ph_, p_Pw_}; - std::vector conv2d_bias(p_O_, 0); - - auto user_src_memory = - memory({{conv2d_src_tz}, dt::f32, tag::nchw}, eng, data); - auto user_weights_memory = memory( - {{conv2d_weights_tz}, dt::f32, (p_G_ > 1) ? tag::goihw : tag::oihw}, eng, - weights); + auto user_src_memory = memory({{conv2d_src_tz}, dt::f32, tag::nchw}, eng, data); + auto user_weights_memory = + memory({{conv2d_weights_tz}, dt::f32, (p_G_ > 1) ? tag::goihw : tag::oihw}, eng, weights); auto conv2d_user_bias_memory = - memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, conv2d_bias.data()); + memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, bias); auto conv2d_src_md = memory::desc({conv2d_src_tz}, dt::f32, tag::any); auto conv2d_bias_md = memory::desc({conv2d_bias_tz}, dt::f32, tag::any); @@ -87,10 +82,9 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, auto conv2d_dst_md = memory::desc({conv2d_dst_tz}, dt::f32, tag::nchw); auto conv2d_desc = convolution_forward::desc( - prop_kind::forward_inference, algorithm::convolution_direct, - conv2d_src_md, conv2d_weights_md, conv2d_bias_md, conv2d_dst_md, - conv2d_strides, conv2d_padding, conv2d_padding); - auto conv2d_prim_desc = convolution_forward::primitive_desc(conv2d_desc, eng); + prop_kind::forward_inference, algorithm::convolution_direct, conv2d_src_md, conv2d_weights_md, + conv2d_bias_md, conv2d_dst_md, conv2d_strides, conv2d_padding, conv2d_padding); + auto conv2d_prim_desc = convolution_forward::primitive_desc(conv2d_desc, attr, eng); auto conv2d_src_memory = user_src_memory; auto conv2d_weights_memory = user_weights_memory; @@ -105,6 +99,39 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, read_from_dnnl_memory(out, conv2d_dst_memory); } +extern "C" void dnnl_conv2d(float* data, float* weights, float* out, + int p_N_, int p_C_, int p_H_, int p_W_, + int p_O_, int p_G_, int p_Ph_, int p_Pw_, + int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_) { + primitive_attr attr; + 
std::vector bias(p_O_, 0); + return dnnl_conv2d_common(data, weights, bias.data(), out, + p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, + p_Ph_, p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, + attr); +} + +primitive_attr create_attr_with_relu_post_op() { + post_ops ops; + ops.append_eltwise(1.f, algorithm::eltwise_relu, 0.f, 0.f); + + primitive_attr attr; + attr.set_post_ops(ops); + + return attr; +} + +extern "C" void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* bias, float* out, + int p_N_, int p_C_, int p_H_, int p_W_, int p_O_, + int p_G_, int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_, + int p_Sh_, int p_Sw_) { + return dnnl_conv2d_common(data, weights, bias, out, + p_N_, p_C_, p_H_, p_W_, + p_O_, p_G_, p_Ph_, p_Pw_, + p_Kh_, p_Kw_, p_Sh_, p_Sw_, + create_attr_with_relu_post_op()); +} + extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_, int p_O_) { using tag = memory::format_tag; diff --git a/src/runtime/contrib/dnnl/dnnl_kernel.h b/src/runtime/contrib/dnnl/dnnl_kernel.h index 4d0b100b92ec..b3111030becf 100644 --- a/src/runtime/contrib/dnnl/dnnl_kernel.h +++ b/src/runtime/contrib/dnnl/dnnl_kernel.h @@ -38,6 +38,12 @@ extern "C" TVM_DLL void dnnl_conv2d(float* data, float* weights, float* out, int int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_); +extern "C" TVM_DLL void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* bias, + float* out, int p_N_, int p_C_, int p_H_, + int p_W_, int p_O_, int p_G_, int p_Ph_, + int p_Pw_, int p_Kh_, int p_Kw_, int p_Sh_, + int p_Sw_); + extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_, int p_O_); diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 4ffb37311696..4459505595c4 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -27,6 +27,10 @@ from 
class ConvBiasAddReLUAnnotator(ExprMutator):
    """Annotate ``relu(add(conv2d(...), ...))`` chains for an external backend.

    The expression is walked from the outermost call inwards while a small
    state machine tracks how much of the chain has been seen
    (``ReLU`` -> ``Bias`` -> ``Conv``). When the whole chain matches, every
    operand of the conv2d and every ``Var``/``Constant`` operand of the other
    two calls is wrapped in ``compiler_begin`` and the relu is wrapped in
    ``compiler_end``.

    NOTE(review): when only a prefix of the chain matches (for example a relu
    applied directly to a ``Var``), ``annotate_call`` has already inserted
    ``compiler_begin`` without a matching ``compiler_end``. Confirm that
    callers only run this on graphs where that cannot happen.

    Parameters
    ----------
    backend : str
        Name passed to ``compiler_begin`` / ``compiler_end``.
    """
    # Imported in the class body, as in the original; this also leaves ``enum``
    # reachable as a class attribute.
    import enum
    state = enum.Enum("State", "Init Conv Bias ReLU")

    def __init__(self, backend):
        super().__init__()
        self.current_state = self.state.Init
        self.backend = backend

    @staticmethod
    def _op_name(call):
        """Return ``call.op.name``, or ``None`` when the callee has no name.

        The callee of a call is not necessarily an operator; reading ``.name``
        unconditionally would raise ``AttributeError`` for such calls.
        """
        return getattr(call.op, "name", None)

    def annotate_call(self, call):
        """Rebuild ``call`` with its operands visited and marked as inputs.

        All operands of a conv2d are wrapped in ``compiler_begin``; for any
        other call only operands that are ``Var`` or ``Constant`` after
        visiting are wrapped.
        """
        is_conv = self._op_name(call) == "nn.conv2d"
        leaf_types = (relay.expr.Var, relay.expr.Constant)
        new_args = []
        for arg in call.args:
            new_arg = super().visit(arg)
            if is_conv or isinstance(new_arg, leaf_types):
                new_arg = compiler_begin(new_arg, self.backend)
            new_args.append(new_arg)
        return relay.Call(call.op, new_args, call.attrs, call.type_args)

    def visit_call(self, call):
        """Advance the state machine on ``call`` and annotate on a match."""
        op_name = self._op_name(call)
        state = self.state

        if op_name == "nn.conv2d" and self.current_state == state.Bias:
            # Leave the Bias state before visiting the operands so that a
            # conv2d nested directly inside this one is not matched as well.
            self.current_state = state.Conv
            ret = self.annotate_call(call)
            # Visiting the operands may have run the state machine on an inner
            # chain; restore Conv so the enclosing relu sees the full match.
            self.current_state = state.Conv
            return ret

        if op_name == "add" and self.current_state == state.ReLU:
            self.current_state = state.Bias
            return self.annotate_call(call)

        if op_name == "nn.relu":
            self.current_state = state.ReLU
            op = self.annotate_call(call)
            if self.current_state == state.Conv:
                # The add and conv2d below completed the chain.
                op = compiler_end(op, self.backend)
            self.current_state = state.Init
            return op

        # Any other call, or a conv2d/add outside the expected position,
        # breaks the chain.
        self.current_state = state.Init
        return super().visit_call(call)
def test_partition_conv_bias_relu():
    """Check detection, partitioning and execution of conv2d + bias + relu.

    The test is skipped when the ``relay.ext.dnnl`` codegen is not registered.
    It checks the number of partitioned functions on a small two-block network
    and on MobileNet, then runs both through ``check_result`` against the
    output of an unpartitioned reference module.
    """
    if not tvm.get_global_func("relay.ext.dnnl", True):
        print("skip because DNNL codegen is not available")
        return

    input_shape = (1, 3, 224, 224)

    def get_blocks(prefix, data, in_channel, out_channel,
                   include_bn=True, include_sigmoid=False):
        # conv2d [-> batch_norm] [-> sigmoid] -> relu
        weight = relay.var(f"{prefix}weight")
        gamma = relay.var(f"{prefix}bn_gamma")
        beta = relay.var(f"{prefix}bn_beta")
        moving_mean = relay.var(f"{prefix}bn_mean")
        moving_var = relay.var(f"{prefix}bn_var")

        out = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3),
                              channels=out_channel, padding=(1, 1))
        if include_bn:
            out = relay.nn.batch_norm(out, gamma, beta, moving_mean, moving_var)[0]
        if include_sigmoid:
            # dummy layer to prevent pattern detection
            out = relay.sigmoid(out)
        return relay.nn.relu(out)

    def get_net(include_bn=True, include_sigmoid=False):
        data = relay.var("data", relay.TensorType(input_shape, "float32"))
        first = get_blocks("layer1_", data, 3, 16, include_bn, include_sigmoid)
        second = get_blocks("layer2_", first, 16, 16, include_bn, include_sigmoid)
        return relay.Function(relay.analysis.free_vars(second), second)

    def pre_optimize(mod, params):
        passes = [
            relay.transform.InferType(),
            relay.transform.SimplifyInference(),
            relay.transform.FoldConstant(),
            relay.transform.FoldScaleAxis(),
        ]
        remove_bn_pass = transform.Sequential(passes)

        # This is required for constant folding
        mod["main"] = bind_params_by_name(mod["main"], params)

        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
            return remove_bn_pass(mod)

    def partition_module(mod):
        annotator = ConvBiasAddReLUAnnotator("dnnl")
        mod["main"] = annotator.visit(mod["main"])
        return transform.PartitionGraph()(mod)

    def get_partitions(mod):
        # Collect every Function reachable from main other than main itself.
        found = []

        def collect(expr):
            if isinstance(expr, _expr.Function) and expr != mod["main"]:
                found.append(expr)

        analysis.post_order_visit(mod["main"], collect)
        return found

    def count_partitions(mod, params):
        return len(get_partitions(partition_module(pre_optimize(mod, params))))

    def check_detect_pattern(include_bn, include_sigmoid, num_expected_partition):
        net = get_net(include_bn, include_sigmoid)
        mod, params = tvm.relay.testing.create_workload(net)
        assert count_partitions(mod, params) == num_expected_partition

    def check_partition():
        # conv + bn + relu -> detection succeeds
        check_detect_pattern(True, False, 2)
        # conv + relu -> fail
        check_detect_pattern(False, False, 0)
        # conv + bn + sigmoid + relu -> fail
        check_detect_pattern(True, True, 0)

    def check_partition_mobilenet():
        mod, params = relay.testing.mobilenet.get_workload()
        assert count_partitions(mod, params) == 27

    def check_exec(mod, params, ref_mod, ref_params, out_shape):
        i_data = np.random.randn(*input_shape).astype(np.float32)
        ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu(0))
        ref_res = ref_ex.evaluate()(i_data, **ref_params)
        compile_engine.get().clear()

        partitioned = partition_module(pre_optimize(mod, params))

        check_result(partitioned, {"data": i_data},
                     out_shape, ref_res.asnumpy(), tol=1e-5, params=params)

    check_partition()
    check_partition_mobilenet()

    net = get_net()
    mod, params = tvm.relay.testing.create_workload(net)
    ref_mod, ref_params = tvm.relay.testing.create_workload(net)
    check_exec(mod, params, ref_mod, ref_params, (1, 16, 224, 224))

    mod, params = relay.testing.mobilenet.get_workload()
    ref_mod, ref_params = relay.testing.mobilenet.get_workload()
    check_exec(mod, params, ref_mod, ref_params, (1, 1000))
test_partition_conv_bias_relu()