[7/10] Code generation for Pooling and Fully Connected via CMSIS-NN #9531
Changes from all commits: 093f3e1, 2ea9511, c4d7c89, a51669a, a05b17d
@@ -84,17 +84,9 @@ class RelayToTIRVisitor : public MixedModeMutator {
     tir::PrimFunc replacement_func(func_signature, body, VoidType(), Map<tir::Var, tir::Buffer>(),
                                    DictAttrs(dict_attrs));
     ir_module_->Add(global_var, replacement_func);
   }
 
-  Array<PrimExpr> CMSISNNDimensions(const Array<PrimExpr>& shape) {
-    ICHECK(shape.size() == 4) << "Supports only CMSIS-NN shapes of dimension 4.";
-    return Array<PrimExpr>{ToArg(qnn::get_const_int(shape[0])), ToArg(qnn::get_const_int(shape[1])),
-                           ToArg(qnn::get_const_int(shape[2])),
-                           ToArg(qnn::get_const_int(shape[3]))};
-  }
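For context on the deletion: the helper did nothing beyond unpacking a 4-d shape into four scalar arguments, and the later hunks append the shape `Array` to the argument list directly, which makes it redundant. A minimal sketch of the equivalent flattening (illustrative names, with `std::vector<int64_t>` standing in for `tvm::Array<PrimExpr>`; not TVM code):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the removed helper: it only unpacked a 4-d
// shape into the four scalars appended to the extern call's arguments.
std::vector<int64_t> FlattenDims(const std::vector<int64_t>& shape) {
  assert(shape.size() == 4 && "CMSIS-NN supports only 4-d shapes");
  return {shape[0], shape[1], shape[2], shape[3]};  // e.g. N, H, W, C
}
```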
|
|
   void EmitConv2D(const GlobalVar& global_var, const Expr& expr) {
     const CallNode* clip_call = nullptr;
     const CallNode* requantize_call = nullptr;
@@ -164,21 +156,15 @@ class RelayToTIRVisitor : public MixedModeMutator {
                                         ToArg(dilation_w), ToArg(dilation_h), ToArg(clip_min),
                                         ToArg(clip_max)};
 
-    // cmsis_nn_dims *input_dims (NHWC)
+    // layout NHWC
     Array<PrimExpr> input_shape = conv2d_call->args[0]->type_as<TensorTypeNode>()->shape;
-    Array<PrimExpr> input_dims = CMSISNNDimensions(input_shape);
 
-    // cmsis_nn_dims *filter_dims (OHWI for Conv2D and IHWO for depthwise)
+    // OHWI for Conv2D and IHWO for depthwise
|
Contributor: Same here, where is the information that OHWI is used?

Author: Same as above. Maybe I should explicitly say that this is for the CMSIS-NN API?

Contributor: I think the information is that the following line produces a tensor in that layout, and subsequent accesses assume it?

Author: Yes, the comment got displaced in the process of review and new edits 😆
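To make the layout discussion concrete, this is how the 4-d dims line up for the CMSIS-NN convolution API, as inferred from the shapes this function constructs (a sketch; the CMSIS-NN headers are authoritative):

```cpp
// Input/output tensors, NHWC:    {batch, height, width, channels}
// Regular conv2d filter, OHWI:   {out_channels, kernel_h, kernel_w, in_channels}
// Depthwise conv2d filter, IHWO: {1, kernel_h, kernel_w, out_channels}
```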
     Array<PrimExpr> filter_shape = conv2d_call->args[1]->type_as<TensorTypeNode>()->shape;
-    Array<PrimExpr> filter_dims = CMSISNNDimensions(filter_shape);
 
-    // cmsis_nn_dims *bias_dims
     Array<PrimExpr> bias_shape{1, 1, 1, out_channels};
-    Array<PrimExpr> bias_dims = CMSISNNDimensions(bias_shape);
 
-    // cmsis_nn_dims *output_dims (same order as input_dims)
     Array<PrimExpr> output_shape = conv2d_call->type_as<TensorTypeNode>()->shape;
-    Array<PrimExpr> output_dims = CMSISNNDimensions(output_shape);
 
     int32_t depth_multiplier = -1;
     int kernel_pos_o = kernel_layout.find("O");
@@ -194,7 +180,7 @@ class RelayToTIRVisitor : public MixedModeMutator {
     if (depth_multiplier != -1) {
       cmsisnn_api = "arm_depthwise_conv_wrapper_s8";
       Array<PrimExpr> depthwise_filter_shape{1, filter_shape[0], filter_shape[1], out_channels};
|
Contributor: Is it possible to use something like …

Author: Yes, partly done above. Missed here.
-      filter_dims = CMSISNNDimensions(depthwise_filter_shape);
       filter_shape = depthwise_filter_shape;
     }
 
     tvm::Array<PrimExpr> call_ext_args = {tir::StringImm(cmsisnn_api), input, filter, multiplier};
@@ -216,10 +202,10 @@ class RelayToTIRVisitor : public MixedModeMutator {
                                                 ToArg(context_buffer_size)};
 
     scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args);
-    scalar_args = tvm::runtime::Concat(scalar_args, input_dims);
-    scalar_args = tvm::runtime::Concat(scalar_args, filter_dims);
-    scalar_args = tvm::runtime::Concat(scalar_args, bias_dims);
-    scalar_args = tvm::runtime::Concat(scalar_args, output_dims);
+    scalar_args = tvm::runtime::Concat(scalar_args, input_shape);
+    scalar_args = tvm::runtime::Concat(scalar_args, filter_shape);
+    scalar_args = tvm::runtime::Concat(scalar_args, bias_shape);
+    scalar_args = tvm::runtime::Concat(scalar_args, output_shape);
     call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args);
 
     Array<tir::Var> func_signature{input, filter, multiplier, filter_scale};
 
@@ -234,6 +220,197 @@ class RelayToTIRVisitor : public MixedModeMutator {
                               context_buffer_size);
   }
|
|
+  void EmitFullyConnected(const GlobalVar& global_var, const Expr& expr) {
+    const CallNode* clip_call = nullptr;
+    const CallNode* requantize_call = nullptr;
+    const CallNode* bias_add_call = nullptr;
+    const CallNode* fc_call = nullptr;
+    const CallNode* final_call = expr.as<CallNode>();
+    const OpNode* final_op = final_call->op.as<OpNode>();
+    if (final_op->name == "clip") {
+      clip_call = final_call;
+      requantize_call = clip_call->args[0].as<CallNode>();
+    } else {
+      requantize_call = final_call;
+    }
+    const CallNode* requantize_input = requantize_call->args[0].as<CallNode>();
+    const OpNode* requantize_input_op = requantize_input->op.as<OpNode>();
+    if (requantize_input_op->name == "nn.bias_add") {
+      bias_add_call = requantize_input;
+      fc_call = bias_add_call->args[0].as<CallNode>();
+    } else {
+      fc_call = requantize_input;
+    }
+
+    // TIR variables are created in the order they appear in the Relay partitioned function
+    // %1 = qnn.dense(%input, %weight_const_0, %input_zero_point_scalar, %kernel_zero_point_scalar,
+    //                %input_scale_scalar, %kernel_scale_scalar)
+    // %2 = nn.bias_add(%1, %bias_const_1, axis=1)
+    // %3 = qnn.requantize(%2, %req_input_scale_scalar, %req_input_zero_point_scalar,
+    //                     %output_scale_scalar, %output_zero_point_scalar)
+    // clip(%3, a_min=%min_scalar, a_max=%max_scalar)
+    tir::Var input("input", DataType::Handle(8));
+    tir::Var filter("filter", DataType::Handle(8));
+    tir::Var bias("bias", DataType::Handle(32));
+    tir::Var output("output", DataType::Handle(8));
+
+    // Individual arguments to the struct arguments of the CMSIS-NN API are filled into call_extern
+    // https://github.com/ARM-software/CMSIS_5/blob/def6f800f95661eb3451d317f7d0dde504f6020d/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L50
+
+    // prepare cmsis_nn_fc_params
+    const DenseAttrs* dense_attrs = fc_call->attrs.as<DenseAttrs>();
+    int32_t input_offset = -GetScalarFromConstant<int32_t>(fc_call->args[2]);
+    int32_t filter_offset = -GetScalarFromConstant<int32_t>(fc_call->args[3]);
+    int32_t output_offset = GetScalarFromConstant<int32_t>(requantize_call->args[4]);
+    float input_scale = GetScalarFromConstant<float>(requantize_call->args[1]);
+    float output_scale = GetScalarFromConstant<float>(requantize_call->args[3]);
+    int32_t out_channels = qnn::get_const_int(dense_attrs->units);
+    int32_t clip_min, clip_max;
+    if (clip_call) {
+      const ClipAttrs* clip_attrs = clip_call->attrs.as<ClipAttrs>();
+      clip_min = clip_attrs->a_min;
+      clip_max = clip_attrs->a_max;
+    } else {
+      clip_min = -128;
+      clip_max = 127;
+    }
+
+    double quantized_multiplier =
+        static_cast<double>(input_scale) / static_cast<double>(output_scale);
+    auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(quantized_multiplier);
+    int32_t multiplier = std::get<0>(mult_shift_pair);
+    int32_t shift = std::get<1>(mult_shift_pair);
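For readers unfamiliar with `GetFixedPointMultiplierShift`: it decomposes the real requantization scale into a Q31 integer multiplier and a power-of-two shift, in the gemmlowp/TFLite style. A self-contained sketch of that decomposition (an assumption about the internals, not the TVM source):

```cpp
#include <cmath>
#include <cstdint>
#include <utility>

// Assumed behaviour: decompose a positive real multiplier into a Q31
// integer `multiplier` and a power-of-two `shift` such that
//   real ~= (multiplier / 2^31) * 2^shift.
std::pair<int32_t, int32_t> FixedPointMultiplierShift(double real_multiplier) {
  if (real_multiplier == 0.0) return {0, 0};
  int shift = 0;
  // frexp returns a significand in [0.5, 1) with real = significand * 2^shift
  const double significand = std::frexp(real_multiplier, &shift);
  auto q = static_cast<int64_t>(std::round(significand * (1LL << 31)));
  if (q == (1LL << 31)) {  // rounding pushed the significand up to 1.0
    q /= 2;
    ++shift;
  }
  return {static_cast<int32_t>(q), shift};
}
```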
|
|
+    tvm::Array<PrimExpr> scalar_args = {
+        ToArg(input_offset), ToArg(filter_offset), ToArg(output_offset), ToArg(clip_min),
+        ToArg(clip_max),     ToArg(multiplier),    ToArg(shift)};
+
+    Array<PrimExpr> input_shape = fc_call->args[0]->type_as<TensorTypeNode>()->shape;
+    int32_t batch_size = qnn::get_const_int(input_shape[0]);
+    int32_t in_channels = qnn::get_const_int(input_shape[1]);
+    Array<PrimExpr> cmsisnn_input_shape{input_shape[0], 1, 1, input_shape[1]};
+
+    Array<PrimExpr> cmsisnn_filter_shape{in_channels, 1, 1, out_channels};
+
+    Array<PrimExpr> bias_shape{1, 1, 1, out_channels};
+
+    Array<PrimExpr> cmsisnn_output_shape{batch_size, 1, 1, out_channels};
+
+    tvm::Array<PrimExpr> call_ext_args = {tir::StringImm("arm_fully_connected_s8"), input, filter};
+    if (bias_add_call) {
+      call_ext_args.push_back(bias);
+    }
+    call_ext_args.push_back(output);
+
+    int context_buffer_size = 0;
+    std::string context_buffer_name = "NULL";
+    tvm::Array<PrimExpr> context_buffer_args = {tir::StringImm(context_buffer_name),
+                                                ToArg(context_buffer_size)};
+
+    scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args);
Contributor: I would suggest keeping the declaration/initialization of the vectors that get concatenated closer to the concat.
+    scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_input_shape);
|
Contributor: [Clarity] I think we can bring the declaration of "cmsisnn_input_shape" closer to here, as it is not used before this.

Author: We cannot. It's being used to derive another size in the middle. Also, I would prefer to keep the concatenation of arguments in one place; locating all arguments in the same place has helped me with debugging.

Contributor: I think I'm missing something here :). Isn't it defined in line 291, and is there a use in between?

Author: Sorry, I was looking at a different operator. For some of the others, context buffer creation requires this shape. Just to be consistent, I didn't want to change the code layout.
+    scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_filter_shape);
|
Contributor: [Clarity] I think we can bring the declaration of "cmsisnn_filter_shape" closer to here, as it is not used before this.

Author: Same as above.
+    scalar_args = tvm::runtime::Concat(scalar_args, bias_shape);
|
Contributor: [Clarity] I think we can bring the declaration of "bias_shape" closer to here, as it is not used before this.

Author: Same as above.
+    scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_output_shape);
|
Contributor: [Clarity] I think we can bring the declaration of "cmsisnn_output_shape" closer to here, as it is not used before this.

Author: Same as above.
+    call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args);
+
+    Array<tir::Var> func_signature{input, filter};
+    if (bias_add_call) {
+      func_signature.push_back(bias);
+    }
+    func_signature.push_back(output);
+    CreatePrimFuncForExtern(global_var, func_signature, call_ext_args, context_buffer_name,
+                            context_buffer_size);
+  }
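To summarize the concatenations above, the emitted extern call ends up with the following flat argument order (reconstructed by reading the code; a sketch rather than compiler output):

```cpp
// call_extern("arm_fully_connected_s8",
//             input, filter, [bias,] output,             // buffers
//             context_buffer_name, context_buffer_size,  // workspace
//             input_offset, filter_offset, output_offset,
//             clip_min, clip_max, multiplier, shift,     // fc/quant params
//             N, 1, 1, C_in,      // cmsisnn_input_shape
//             C_in, 1, 1, C_out,  // cmsisnn_filter_shape
//             1, 1, 1, C_out,     // bias_shape
//             N, 1, 1, C_out);    // cmsisnn_output_shape
```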
|
|
+  void EmitPool2D(const GlobalVar& global_var, const Expr& expr, const String pool_name) {
+    Call clip, pool;
+    Call final_call = GetRef<Call>(expr.as<CallNode>());
+    Op final_op = GetRef<Op>(final_call->op.as<OpNode>());
+    if (final_op->name == "clip") {
+      clip = final_call;
+      Call clip_input = GetRef<Call>(clip->args[0].as<CallNode>());
+      Op clip_input_op = GetRef<Op>(clip_input->op.as<OpNode>());
+      if (clip_input_op->name == "cast") {
+        pool = GetRef<Call>(clip_input->args[0].as<CallNode>());
+      } else {  // max_pool2d
+        pool = clip_input;
+      }
+    } else if (final_op->name == "cast") {
+      pool = GetRef<Call>(final_call->args[0].as<CallNode>());
+    } else {  // max_pool2d
+      pool = final_call;
+    }
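The branches above walk the partitioned composite from the outermost operator inwards; the composite bodies they expect look roughly like this (inferred from the traversal, so treat it as a sketch):

```cpp
// cmsis-nn.qnn_avg_pool2d: clip( cast<int8>( avg_pool2d( cast<int32>(x) ) ) )
// cmsis-nn.qnn_max_pool2d: clip( max_pool2d(x) )
// The clip is optional in both cases; when absent, clip_min/clip_max below
// fall back to the full int8 range (-128, 127).
```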
|
|
+    // prepare cmsis_nn_pool_params
+    int32_t stride_h, stride_w, padding_h, padding_w, pool_size_h, pool_size_w;
+    int32_t clip_min, clip_max;
+    std::string cmsisnn_api;
+    if (pool_name == "cmsis-nn.qnn_avg_pool2d") {
+      cmsisnn_api = "arm_avgpool_s8";
+      const AvgPool2DAttrs* attrs = pool->attrs.as<AvgPool2DAttrs>();
+      stride_h = qnn::get_const_int(attrs->strides[0]);
+      stride_w = qnn::get_const_int(attrs->strides[1]);
+      padding_h = qnn::get_const_int(attrs->padding[0]);
+      padding_w = qnn::get_const_int(attrs->padding[1]);
+      pool_size_h = qnn::get_const_int(attrs->pool_size[0]);
+      pool_size_w = qnn::get_const_int(attrs->pool_size[1]);
+    } else {
+      cmsisnn_api = "arm_max_pool_s8";
+      const MaxPool2DAttrs* attrs = pool->attrs.as<MaxPool2DAttrs>();
+      stride_h = qnn::get_const_int(attrs->strides[0]);
+      stride_w = qnn::get_const_int(attrs->strides[1]);
+      padding_h = qnn::get_const_int(attrs->padding[0]);
+      padding_w = qnn::get_const_int(attrs->padding[1]);
+      pool_size_h = qnn::get_const_int(attrs->pool_size[0]);
+      pool_size_w = qnn::get_const_int(attrs->pool_size[1]);
+    }
+    if (clip.defined()) {
+      const ClipAttrs* clip_attrs = clip->attrs.as<ClipAttrs>();
+      clip_min = clip_attrs->a_min;
+      clip_max = clip_attrs->a_max;
+    } else {
+      clip_min = -128;
+      clip_max = 127;
+    }
+
+    tvm::Array<PrimExpr> scalar_args = {ToArg(stride_h), ToArg(stride_w), ToArg(padding_h),
+                                        ToArg(padding_w), ToArg(clip_min), ToArg(clip_max)};
+
+    Array<PrimExpr> input_shape = pool->args[0]->type_as<TensorTypeNode>()->shape;
+    Array<PrimExpr> cmsisnn_input_shape{1, input_shape[1], input_shape[2], input_shape[3]};
+
+    Array<PrimExpr> cmsisnn_filter_shape{1, pool_size_h, pool_size_w, 1};
+
+    Array<PrimExpr> output_shape = pool->type_as<TensorTypeNode>()->shape;
+    Array<PrimExpr> cmsisnn_output_shape{1, output_shape[1], output_shape[2], output_shape[3]};
+
+    tir::Var input("input", DataType::Handle(8));
+    tir::Var output("output", DataType::Handle(8));
+    tvm::Array<PrimExpr> call_ext_args = {tir::StringImm(cmsisnn_api), input, output};
+
+    int context_buffer_size = 0;
+    std::string context_buffer_name = "NULL";
+    if (pool_name == "cmsis-nn.qnn_avg_pool2d") {
+      // TODO(@Mousius): Need to move this into buffer_size calculations
+      context_buffer_size = qnn::get_const_int(input_shape[3]) * sizeof(int32_t);
+      context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++);
+    }
+    tvm::Array<PrimExpr> context_buffer_args = {tir::StringImm(context_buffer_name),
+                                                ToArg(context_buffer_size)};
+
+    scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args);
+    scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_input_shape);
+    scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_filter_shape);
+    scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_output_shape);
+    call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args);
+
+    Array<tir::Var> func_signature{input, output};
+
+    CreatePrimFuncForExtern(global_var, func_signature, call_ext_args, context_buffer_name,
+                            context_buffer_size);
+  }
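On the context buffer computed above: average pooling keeps one int32 accumulator per channel, which is what `input_shape[3] * sizeof(int32_t)` provides. A sketch of that sizing rule (an assumption about what CMSIS-NN's buffer-size helper returns; target-specific variants may differ):

```cpp
#include <cstdint>

// One 32-bit accumulator per input channel for arm_avgpool_s8's scratch area.
int32_t AvgPoolScratchBytes(int32_t input_channels) {
  return input_channels * static_cast<int32_t>(sizeof(int32_t));
}
```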
|
|
   void EmitSoftMax(const GlobalVar& global_var, const Expr& expr) {
     const CallNode* quantize_call = expr.as<CallNode>();
     const CallNode* softmax_call = quantize_call->args[0].as<CallNode>();
@@ -422,6 +599,12 @@ class RelayToTIRVisitor : public MixedModeMutator {
     if (comp_name == "cmsis-nn.qnn_conv2d") {
       EmitConv2D(new_global_var, composite_func->body);
     }
+    if (comp_name == "cmsis-nn.qnn_fully_connected") {
+      EmitFullyConnected(new_global_var, composite_func->body);
+    }
+    if (comp_name == "cmsis-nn.qnn_avg_pool2d" || comp_name == "cmsis-nn.qnn_max_pool2d") {
+      EmitPool2D(new_global_var, composite_func->body, comp_name.value());
+    }
 
     Array<Expr> args;
     for (const auto& arg : call->args) {

Contributor: Sorry Ashutosh, I still don't follow what this comment means. How is it related to the following line?

Author: This was to help understand that the following shape is in NHWC layout, which will eventually be passed on to the CMSIS-NN API.

Contributor: I think it might be worth saying that explicitly :)

Author: Ack