46 changes: 45 additions & 1 deletion python/tvm/relay/op/contrib/cmsisnn.py
@@ -145,6 +145,49 @@ def check_qnn_conv2d(pattern):
            and (not is_depthwise or bias_add is not None)
        )

    def qnn_fully_connected_pattern():
        """Create pattern for qnn.dense with optional bias_add, requantize and clip."""
        qnn_fc = is_op("qnn.dense")(
            wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
        )
        bias_add = is_op("nn.bias_add")(qnn_fc, is_constant())
        req = is_op("qnn.requantize")(
            qnn_fc | bias_add, is_constant(), is_constant(), is_constant(), is_constant()
        )
        clip_or_req = req.optional(is_op("clip"))
        return clip_or_req

    def check_qnn_fully_connected(pattern):
        """Check if the fully connected operator is supported by CMSIS-NN."""
        if str(pattern.op.name) == "clip":
            relu = pattern
            requantize = relu.args[0]
        else:
            requantize = pattern
        requantize_input = requantize.args[0]
        bias_add = None
        bias_dtype = "int32"
        if str(requantize_input.op.name) == "nn.bias_add":
            bias_add = requantize_input
            fc = bias_add.args[0]
            bias_dtype = bias_add.args[1].checked_type.dtype
        else:
            fc = requantize_input
        fc_input = fc.args[0]
        fc_weight = fc.args[1]

        # kernel zero_point should be 0
        kernel_zp = fc.args[3].data.numpy().item(0)

        return (
            fc.attrs.out_dtype == "int32"
            and fc_input.checked_type.dtype == "int8"
            and fc_weight.checked_type.dtype == "int8"
            and pattern.checked_type.dtype == "int8"
            and bias_dtype == "int32"
            and kernel_zp == 0
        )

    def binary_op_pattern(op):
        """Matches QNN binary operation"""
        return is_op(f"qnn.{op}")(
@@ -166,8 +209,9 @@ def check_qnn_binary_op(extract):
        )

     return [
-        ("cmsis-nn.qnn_softmax", qnn_softmax_pattern(), check_qnn_softmax),
         ("cmsis-nn.qnn_conv2d", qnn_conv2d_pattern(), check_qnn_conv2d),
+        ("cmsis-nn.qnn_fully_connected", qnn_fully_connected_pattern(), check_qnn_fully_connected),
         ("cmsis-nn.qnn_mul", binary_op_pattern("mul"), check_qnn_binary_op),
         ("cmsis-nn.qnn_add", binary_op_pattern("add"), check_qnn_binary_op),
+        ("cmsis-nn.qnn_softmax", qnn_softmax_pattern(), check_qnn_softmax),
     ]
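
As an illustration (not part of the patch), here is a minimal sketch of a Relay graph that the new qnn_fully_connected pattern is intended to match. All shapes, scales, and zero points are made-up values; the op signatures follow the public relay.qnn API:

```python
import numpy as np
import tvm
from tvm import relay

in_units, out_units = 16, 8
x = relay.var("input", shape=(1, in_units), dtype="int8")
w = relay.const(np.zeros((out_units, in_units), dtype="int8"))
fc = relay.qnn.op.dense(
    x,
    w,
    input_zero_point=relay.const(0),
    kernel_zero_point=relay.const(0),  # must be 0, see check_qnn_fully_connected
    input_scale=relay.const(0.5),
    kernel_scale=relay.const(0.5),
    units=out_units,
    out_dtype="int32",  # also required by check_qnn_fully_connected
)
fc = relay.nn.bias_add(fc, relay.const(np.zeros((out_units,), dtype="int32")))
out = relay.qnn.op.requantize(
    fc,
    input_scale=relay.const(0.25),
    input_zero_point=relay.const(0),
    output_scale=relay.const(0.1),
    output_zero_point=relay.const(0),
    out_dtype="int8",
)
out = relay.clip(out, a_min=-128, a_max=127)
```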
4 changes: 2 additions & 2 deletions src/relay/backend/contrib/cmsisnn/generate_constants.cc
@@ -134,10 +134,10 @@ class GenerateConstantsMutator : public MixedModeMutator {
   int32_t* multiplier = static_cast<int32_t*>(multiplier_nda->data);
   int32_t* shift = static_cast<int32_t*>(shift_nda->data);
   for (int i = 0; i < out_channels; ++i) {
-    double effective_output_scale =
+    double quantized_multiplier =
         static_cast<double>(input_scales[i]) / static_cast<double>(output_scale);
     std::tie(*(multiplier + i), *(shift + i)) =
-        tvm::relay::qnn::GetFixedPointMultiplierShift(effective_output_scale);
+        tvm::relay::qnn::GetFixedPointMultiplierShift(quantized_multiplier);
   }

   // Create constants from requantization multiplier and shift
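
For context, GetFixedPointMultiplierShift converts each real-valued requantization scale into a Q31 fixed-point multiplier plus a power-of-two shift, so no floating point runs on the target. A rough Python sketch of that decomposition, assuming it mirrors TVM's frexp-based implementation:

```python
import math

def fixed_point_multiplier_shift(scale):
    """Return (multiplier, shift) with scale ~= multiplier * 2**(shift - 31)."""
    if scale == 0.0:
        return 0, 0
    significand, exponent = math.frexp(scale)  # scale = significand * 2**exponent
    multiplier = round(significand * (1 << 31))  # Q31 significand, 0.5 <= |significand| < 1
    if multiplier == (1 << 31):  # rounding may carry the significand up to exactly 1.0
        multiplier //= 2
        exponent += 1
    return multiplier, exponent

mult, shift = fixed_point_multiplier_shift(0.5 / 0.1)  # an input_scale / output_scale ratio
assert abs(mult * 2.0 ** (shift - 31) - 5.0) < 1e-6
```

CMSIS-NN consumes the pair in pure integer arithmetic (a doubling high multiply followed by a rounding shift), which is why the decomposition happens at compile time here.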
133 changes: 116 additions & 17 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -84,17 +84,9 @@ class RelayToTIRVisitor : public MixedModeMutator {

     tir::PrimFunc replacement_func(func_signature, body, VoidType(), Map<tir::Var, tir::Buffer>(),
                                    DictAttrs(dict_attrs));
-
     ir_module_->Add(global_var, replacement_func);
   }
 
-  Array<PrimExpr> CMSISNNDimensions(const Array<PrimExpr>& shape) {
-    ICHECK(shape.size() == 4) << "Supports only CMSIS-NN shapes of dimension 4.";
-    return Array<PrimExpr>{ToArg(qnn::get_const_int(shape[0])), ToArg(qnn::get_const_int(shape[1])),
-                           ToArg(qnn::get_const_int(shape[2])),
-                           ToArg(qnn::get_const_int(shape[3]))};
-  }
-
   void EmitConv2D(const GlobalVar& global_var, const Expr& expr) {
     const CallNode* clip_call = nullptr;
     const CallNode* requantize_call = nullptr;
@@ -166,19 +158,15 @@

     // cmsis_nn_dims *input_dims (NHWC)
     Array<PrimExpr> input_shape = conv2d_call->args[0]->type_as<TensorTypeNode>()->shape;
-    Array<PrimExpr> input_dims = CMSISNNDimensions(input_shape);
 
     // cmsis_nn_dims *filter_dims (OHWI for Conv2D and IHWO for depthwise)
     Array<PrimExpr> filter_shape = conv2d_call->args[1]->type_as<TensorTypeNode>()->shape;
-    Array<PrimExpr> filter_dims = CMSISNNDimensions(filter_shape);
 
     // cmsis_nn_dims *bias_dims
     Array<PrimExpr> bias_shape{1, 1, 1, out_channels};
-    Array<PrimExpr> bias_dims = CMSISNNDimensions(bias_shape);
 
     // cmsis_nn_dims *output_dims (same order as input_dims)
     Array<PrimExpr> output_shape = conv2d_call->type_as<TensorTypeNode>()->shape;
-    Array<PrimExpr> output_dims = CMSISNNDimensions(output_shape);
 
     int32_t depth_multiplier = -1;
     int kernel_pos_o = kernel_layout.find("O");
@@ -194,7 +182,7 @@
     if (depth_multiplier != -1) {
       cmsisnn_api = "arm_depthwise_conv_wrapper_s8";
       Array<PrimExpr> depthwise_filter_shape{1, filter_shape[0], filter_shape[1], out_channels};
-      filter_dims = CMSISNNDimensions(depthwise_filter_shape);
+      filter_shape = depthwise_filter_shape;
     }
 
     tvm::Array<PrimExpr> call_ext_args = {tir::StringImm(cmsisnn_api), input, filter, multiplier};
@@ -216,10 +204,10 @@
                                                 ToArg(context_buffer_size)};
 
     scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args);
-    scalar_args = tvm::runtime::Concat(scalar_args, input_dims);
-    scalar_args = tvm::runtime::Concat(scalar_args, filter_dims);
-    scalar_args = tvm::runtime::Concat(scalar_args, bias_dims);
-    scalar_args = tvm::runtime::Concat(scalar_args, output_dims);
+    scalar_args = tvm::runtime::Concat(scalar_args, input_shape);
+    scalar_args = tvm::runtime::Concat(scalar_args, filter_shape);
+    scalar_args = tvm::runtime::Concat(scalar_args, bias_shape);
+    scalar_args = tvm::runtime::Concat(scalar_args, output_shape);
     call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args);
 
     Array<tir::Var> func_signature{input, filter, multiplier, filter_scale};
@@ -234,6 +222,114 @@
                            context_buffer_size);
  }

  void EmitFullyConnected(const GlobalVar& global_var, const Expr& expr) {
    const CallNode* clip_call = nullptr;
    const CallNode* requantize_call = nullptr;
    const CallNode* bias_add_call = nullptr;
    const CallNode* fc_call = nullptr;
    const CallNode* final_call = expr.as<CallNode>();
    const OpNode* final_op = final_call->op.as<OpNode>();
    if (final_op->name == "clip") {
      clip_call = final_call;
      requantize_call = clip_call->args[0].as<CallNode>();
    } else {
      requantize_call = final_call;
    }
    const CallNode* requantize_input = requantize_call->args[0].as<CallNode>();
    const OpNode* requantize_input_op = requantize_input->op.as<OpNode>();
    if (requantize_input_op->name == "nn.bias_add") {
      bias_add_call = requantize_input;
      fc_call = bias_add_call->args[0].as<CallNode>();
    } else {
      fc_call = requantize_input;
    }

    // TIR variables are created in the order they appear in the Relay partitioned function
    // %1 = qnn.dense(%input, %weight_const_0, %input_zero_point_scalar, %kernel_zero_point_scalar,
    //                %input_scale_scalar, %kernel_scale_scalar)
    // %2 = nn.bias_add(%1, %bias_const_1, axis=1)
    // %3 = qnn.requantize(%2, %req_input_scale_scalar, %req_input_zero_point_scalar,
    //                     %output_scale_scalar, %output_zero_point_scalar)
    // clip(%3, a_min=%min_scalar, a_max=%max_scalar)
    tir::Var input("input", DataType::Handle(8));
    tir::Var filter("filter", DataType::Handle(8));
    tir::Var bias("bias", DataType::Handle(32));
    tir::Var output("output", DataType::Handle(8));
    // Individual fields of the CMSIS-NN API's struct arguments are flattened into the
    // call_extern arguments; see the analogous convolution wrapper for the struct layout:
    // https://github.com/ARM-software/CMSIS_5/blob/def6f800f95661eb3451d317f7d0dde504f6020d/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L50

    // prepare cmsis_nn_fc_params
    const DenseAttrs* dense_attrs = fc_call->attrs.as<DenseAttrs>();
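    // Note: CMSIS-NN expects "offsets" rather than zero points: the input and filter
    // offsets below are the negated QNN zero points, while the output offset is used as-is.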
    int32_t input_offset = -GetScalarFromConstant<int32_t>(fc_call->args[2]);
    int32_t filter_offset = -GetScalarFromConstant<int32_t>(fc_call->args[3]);
    int32_t output_offset = GetScalarFromConstant<int32_t>(requantize_call->args[4]);
    float input_scale = GetScalarFromConstant<float>(requantize_call->args[1]);
    float output_scale = GetScalarFromConstant<float>(requantize_call->args[3]);
    int32_t out_channels = qnn::get_const_int(dense_attrs->units);
    int32_t clip_min, clip_max;
    if (clip_call) {
      const ClipAttrs* clip_attrs = clip_call->attrs.as<ClipAttrs>();
      clip_min = clip_attrs->a_min;
      clip_max = clip_attrs->a_max;
    } else {
      clip_min = -128;
      clip_max = 127;
    }

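    // Decompose the effective requantization scale (input_scale / output_scale) into the
    // Q31 fixed-point multiplier and shift that arm_fully_connected_s8 expects.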
    double quantized_multiplier =
        static_cast<double>(input_scale) / static_cast<double>(output_scale);
    auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(quantized_multiplier);
    int32_t multiplier = std::get<0>(mult_shift_pair);
    int32_t shift = std::get<1>(mult_shift_pair);

    tvm::Array<PrimExpr> scalar_args = {
        ToArg(input_offset), ToArg(filter_offset), ToArg(output_offset), ToArg(clip_min),
        ToArg(clip_max),     ToArg(multiplier),    ToArg(shift)};

    // cmsis_nn_dims *input_dims
    Array<PrimExpr> input_shape = fc_call->args[0]->type_as<TensorTypeNode>()->shape;
    int32_t batch_size = qnn::get_const_int(input_shape[0]);
    int32_t in_channels = qnn::get_const_int(input_shape[1]);
    Array<PrimExpr> cmsisnn_input_shape{input_shape[0], 1, 1, input_shape[1]};

    // cmsis_nn_dims *filter_dims
    Array<PrimExpr> cmsisnn_filter_shape{in_channels, 1, 1, out_channels};

    // cmsis_nn_dims *bias_dims
    Array<PrimExpr> bias_shape{1, 1, 1, out_channels};

    // cmsis_nn_dims *output_dims
    Array<PrimExpr> cmsisnn_output_shape{batch_size, 1, 1, out_channels};

    std::string cmsisnn_api = "arm_fully_connected_s8";
    tvm::Array<PrimExpr> call_ext_args = {tir::StringImm(cmsisnn_api), input, filter};
    if (bias_add_call) {
      call_ext_args.push_back(bias);
    }
    call_ext_args.push_back(output);

    int context_buffer_size = 0;
    std::string context_buffer_name = "NULL";
    tvm::Array<PrimExpr> context_buffer_args = {tir::StringImm(context_buffer_name),
                                                ToArg(context_buffer_size)};

    scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args);
    scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_input_shape);
    scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_filter_shape);
    scalar_args = tvm::runtime::Concat(scalar_args, bias_shape);
    scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_output_shape);
    call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args);
[Review thread on the argument concatenation above]

Contributor: nit: Would it be possible to keep the definitions of the sub-categories that get concatenated closer to here, so it is easier to follow?

Contributor Author: I didn't get this. Which ones should be close together and are not?

Contributor: e.g.

    Array<PrimExpr> cmsisnn_input_shape{input_shape[0], 1, 1, input_shape[1]};
    scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_input_shape);

    Array<PrimExpr> cmsisnn_filter_shape{in_channels, 1, 1, out_channels};
    scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_filter_shape);

Contributor Author: I see what you mean.

Contributor Author (asparkhi, Dec 2, 2021): Can't do because of a cross dependency in this case 😕

    Array<tir::Var> func_signature{input, filter};
    if (bias_add_call) {
      func_signature.push_back(bias);
    }
    func_signature.push_back(output);
    CreatePrimFuncForExtern(global_var, func_signature, call_ext_args, context_buffer_name,
                            context_buffer_size);
  }

  void EmitSoftMax(const GlobalVar& global_var, const Expr& expr) {
    const CallNode* quantize_call = expr.as<CallNode>();
    const CallNode* softmax_call = quantize_call->args[0].as<CallNode>();
@@ -422,6 +518,9 @@
     if (comp_name == "cmsis-nn.qnn_conv2d") {
       EmitConv2D(new_global_var, composite_func->body);
     }
+    if (comp_name == "cmsis-nn.qnn_fully_connected") {
+      EmitFullyConnected(new_global_var, composite_func->body);
+    }
 
     Array<Expr> args;
     for (const auto& arg : call->args) {
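
Taken together, the pattern table entry and EmitFullyConnected offload an int8 qnn.dense chain end to end. A usage sketch, reusing x and out from the pattern example above and the pre-existing partition_for_cmsisnn entry point in this same module:

```python
from tvm.relay.op.contrib import cmsisnn

# Wrap the example graph in a module and let the pattern table partition it.
mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(out), out))
partitioned = cmsisnn.partition_for_cmsisnn(mod)
# The qnn.dense -> bias_add -> requantize -> clip chain should now sit in a
# function annotated with Composite="cmsis-nn.qnn_fully_connected".
print(partitioned)
```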