-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[6/10] Code generation for fully connected layer via CMSIS-NN #9456
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -84,17 +84,9 @@ class RelayToTIRVisitor : public MixedModeMutator { | |
|
|
||
| tir::PrimFunc replacement_func(func_signature, body, VoidType(), Map<tir::Var, tir::Buffer>(), | ||
| DictAttrs(dict_attrs)); | ||
|
|
||
| ir_module_->Add(global_var, replacement_func); | ||
| } | ||
|
|
||
| Array<PrimExpr> CMSISNNDimensions(const Array<PrimExpr>& shape) { | ||
| ICHECK(shape.size() == 4) << "Supports only CMSIS-NN shapes of dimension 4."; | ||
| return Array<PrimExpr>{ToArg(qnn::get_const_int(shape[0])), ToArg(qnn::get_const_int(shape[1])), | ||
| ToArg(qnn::get_const_int(shape[2])), | ||
| ToArg(qnn::get_const_int(shape[3]))}; | ||
| } | ||
|
|
||
| void EmitConv2D(const GlobalVar& global_var, const Expr& expr) { | ||
| const CallNode* clip_call = nullptr; | ||
| const CallNode* requantize_call = nullptr; | ||
|
|
@@ -166,19 +158,15 @@ class RelayToTIRVisitor : public MixedModeMutator { | |
|
|
||
| // cmsis_nn_dims *input_dims (NHWC) | ||
| Array<PrimExpr> input_shape = conv2d_call->args[0]->type_as<TensorTypeNode>()->shape; | ||
| Array<PrimExpr> input_dims = CMSISNNDimensions(input_shape); | ||
|
|
||
| // cmsis_nn_dims *filter_dims (OHWI for Conv2D and IHWO for depthwise) | ||
| Array<PrimExpr> filter_shape = conv2d_call->args[1]->type_as<TensorTypeNode>()->shape; | ||
| Array<PrimExpr> filter_dims = CMSISNNDimensions(filter_shape); | ||
|
|
||
| // cmsis_nn_dims *bias_dims | ||
| Array<PrimExpr> bias_shape{1, 1, 1, out_channels}; | ||
| Array<PrimExpr> bias_dims = CMSISNNDimensions(bias_shape); | ||
|
|
||
| // cmsis_nn_dims *output_dims (same order as input_dims) | ||
| Array<PrimExpr> output_shape = conv2d_call->type_as<TensorTypeNode>()->shape; | ||
| Array<PrimExpr> output_dims = CMSISNNDimensions(output_shape); | ||
|
|
||
| int32_t depth_multiplier = -1; | ||
| int kernel_pos_o = kernel_layout.find("O"); | ||
|
|
@@ -194,7 +182,7 @@ class RelayToTIRVisitor : public MixedModeMutator { | |
| if (depth_multiplier != -1) { | ||
| cmsisnn_api = "arm_depthwise_conv_wrapper_s8"; | ||
| Array<PrimExpr> depthwise_filter_shape{1, filter_shape[0], filter_shape[1], out_channels}; | ||
| filter_dims = CMSISNNDimensions(depthwise_filter_shape); | ||
| filter_shape = depthwise_filter_shape; | ||
| } | ||
|
|
||
| tvm::Array<PrimExpr> call_ext_args = {tir::StringImm(cmsisnn_api), input, filter, multiplier}; | ||
|
|
@@ -216,10 +204,10 @@ class RelayToTIRVisitor : public MixedModeMutator { | |
| ToArg(context_buffer_size)}; | ||
|
|
||
| scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args); | ||
| scalar_args = tvm::runtime::Concat(scalar_args, input_dims); | ||
| scalar_args = tvm::runtime::Concat(scalar_args, filter_dims); | ||
| scalar_args = tvm::runtime::Concat(scalar_args, bias_dims); | ||
| scalar_args = tvm::runtime::Concat(scalar_args, output_dims); | ||
| scalar_args = tvm::runtime::Concat(scalar_args, input_shape); | ||
| scalar_args = tvm::runtime::Concat(scalar_args, filter_shape); | ||
| scalar_args = tvm::runtime::Concat(scalar_args, bias_shape); | ||
| scalar_args = tvm::runtime::Concat(scalar_args, output_shape); | ||
| call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args); | ||
|
|
||
| Array<tir::Var> func_signature{input, filter, multiplier, filter_scale}; | ||
|
|
@@ -234,6 +222,114 @@ class RelayToTIRVisitor : public MixedModeMutator { | |
| context_buffer_size); | ||
| } | ||
|
|
||
| void EmitFullyConnected(const GlobalVar& global_var, const Expr& expr) { | ||
| const CallNode* clip_call = nullptr; | ||
| const CallNode* requantize_call = nullptr; | ||
| const CallNode* bias_add_call = nullptr; | ||
| const CallNode* fc_call = nullptr; | ||
| const CallNode* final_call = expr.as<CallNode>(); | ||
| const OpNode* final_op = final_call->op.as<OpNode>(); | ||
| if (final_op->name == "clip") { | ||
| clip_call = final_call; | ||
| requantize_call = clip_call->args[0].as<CallNode>(); | ||
| } else { | ||
| requantize_call = final_call; | ||
| } | ||
| const CallNode* requantize_input = requantize_call->args[0].as<CallNode>(); | ||
| const OpNode* requantize_input_op = requantize_input->op.as<OpNode>(); | ||
| if (requantize_input_op->name == "nn.bias_add") { | ||
| bias_add_call = requantize_input; | ||
| fc_call = bias_add_call->args[0].as<CallNode>(); | ||
| } else { | ||
| fc_call = requantize_input; | ||
| } | ||
|
|
||
| // TIR variables are created in the order they appear in the Relay partitioned function | ||
| // %1 = qnn.dense(%input, %weight_const_0, input_zero_point_scalar, kernel_zero_point_scalar, | ||
| // %input_scale_scalar, %kernel_scale_scalar) | ||
| // %2 = nn.bias_add(%1, %bias_const_1, axis=1) | ||
| // %3 = qnn.requantize(%2, %req_input_scale_scalar, %req_input_zero_point_scalar, | ||
| // %output_scale_scalar, %output_zero_point_scalar) | ||
| // clip(%3, a_min=%min_scalar, a_max=%max_scalar) | ||
| tir::Var input("input", DataType::Handle(8)); | ||
| tir::Var filter("filter", DataType::Handle(8)); | ||
| tir::Var bias("bias", DataType::Handle(32)); | ||
| tir::Var output("output", DataType::Handle(8)); | ||
|
|
||
| // Individual arguments to the structs arguments of the CMSIS-NN API are filled into call_extern | ||
| // https://github.com/ARM-software/CMSIS_5/blob/def6f800f95661eb3451d317f7d0dde504f6020d/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L50 | ||
|
|
||
| // prepare cmsis_nn_fc_params | ||
| const DenseAttrs* dense_attrs = fc_call->attrs.as<DenseAttrs>(); | ||
| int32_t input_offset = -GetScalarFromConstant<int32_t>(fc_call->args[2]); | ||
| int32_t filter_offset = -GetScalarFromConstant<int32_t>(fc_call->args[3]); | ||
| int32_t output_offset = GetScalarFromConstant<int32_t>(requantize_call->args[4]); | ||
| float input_scale = GetScalarFromConstant<float>(requantize_call->args[1]); | ||
| float output_scale = GetScalarFromConstant<float>(requantize_call->args[3]); | ||
| int32_t out_channels = qnn::get_const_int(dense_attrs->units); | ||
| int32_t clip_min, clip_max; | ||
| if (clip_call) { | ||
| const ClipAttrs* clip_attrs = clip_call->attrs.as<ClipAttrs>(); | ||
| clip_min = clip_attrs->a_min; | ||
| clip_max = clip_attrs->a_max; | ||
| } else { | ||
| clip_min = -128; | ||
| clip_max = 127; | ||
| } | ||
|
|
||
| double quantized_multiplier = | ||
| static_cast<double>(input_scale) / static_cast<double>(output_scale); | ||
| auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(quantized_multiplier); | ||
| int32_t multiplier = std::get<0>(mult_shift_pair); | ||
| int32_t shift = std::get<1>(mult_shift_pair); | ||
|
|
||
| tvm::Array<PrimExpr> scalar_args = { | ||
| ToArg(input_offset), ToArg(filter_offset), ToArg(output_offset), ToArg(clip_min), | ||
| ToArg(clip_max), ToArg(multiplier), ToArg(shift)}; | ||
|
|
||
| // cmsis_nn_dims *input_dims | ||
| Array<PrimExpr> input_shape = fc_call->args[0]->type_as<TensorTypeNode>()->shape; | ||
| int32_t batch_size = qnn::get_const_int(input_shape[0]); | ||
| int32_t in_channels = qnn::get_const_int(input_shape[1]); | ||
| Array<PrimExpr> cmsisnn_input_shape{input_shape[0], 1, 1, input_shape[1]}; | ||
|
|
||
| // cmsis_nn_dims *filter_dims | ||
| Array<PrimExpr> cmsisnn_filter_shape{in_channels, 1, 1, out_channels}; | ||
|
|
||
| // cmsis_nn_dims *bias_dims | ||
| Array<PrimExpr> bias_shape{1, 1, 1, out_channels}; | ||
|
|
||
| // cmsis_nn_dims *output_dims | ||
| Array<PrimExpr> cmsisnn_output_shape{batch_size, 1, 1, out_channels}; | ||
|
|
||
| std::string cmsisnn_api = "arm_fully_connected_s8"; | ||
asparkhi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| tvm::Array<PrimExpr> call_ext_args = {tir::StringImm(cmsisnn_api), input, filter}; | ||
| if (bias_add_call) { | ||
| call_ext_args.push_back(bias); | ||
| } | ||
| call_ext_args.push_back(output); | ||
|
|
||
| int context_buffer_size = 0; | ||
| std::string context_buffer_name = "NULL"; | ||
| tvm::Array<PrimExpr> context_buffer_args = {tir::StringImm(context_buffer_name), | ||
| ToArg(context_buffer_size)}; | ||
|
|
||
| scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args); | ||
| scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_input_shape); | ||
| scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_filter_shape); | ||
| scalar_args = tvm::runtime::Concat(scalar_args, bias_shape); | ||
| scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_output_shape); | ||
| call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: would it be possible to keep the definitions of the sub-categories that get concatenated closer to here, so it is easy to follow?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I didn't get this. Which ones should be close together and are not?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. e.g.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I see what you mean
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can't do because of a cross dependency in this case 😕 |
||
|
|
||
| Array<tir::Var> func_signature{input, filter}; | ||
| if (bias_add_call) { | ||
| func_signature.push_back(bias); | ||
| } | ||
| func_signature.push_back(output); | ||
| CreatePrimFuncForExtern(global_var, func_signature, call_ext_args, context_buffer_name, | ||
| context_buffer_size); | ||
| } | ||
|
|
||
| void EmitSoftMax(const GlobalVar& global_var, const Expr& expr) { | ||
| const CallNode* quantize_call = expr.as<CallNode>(); | ||
| const CallNode* softmax_call = quantize_call->args[0].as<CallNode>(); | ||
|
|
@@ -422,6 +518,9 @@ class RelayToTIRVisitor : public MixedModeMutator { | |
| if (comp_name == "cmsis-nn.qnn_conv2d") { | ||
| EmitConv2D(new_global_var, composite_func->body); | ||
| } | ||
| if (comp_name == "cmsis-nn.qnn_fully_connected") { | ||
| EmitFullyConnected(new_global_var, composite_func->body); | ||
| } | ||
|
|
||
| Array<Expr> args; | ||
| for (const auto& arg : call->args) { | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.