Merged
72 changes: 71 additions & 1 deletion python/tvm/relay/op/contrib/cmsisnn.py
@@ -145,6 +145,73 @@ def check_qnn_conv2d(pattern):
and (not is_depthwise or bias_add is not None)
)

def qnn_fully_connected_pattern():
"""Create pattern for qnn.dense with optional Relu."""
qnn_fc = is_op("qnn.dense")(
wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
)
bias_add = is_op("nn.bias_add")(qnn_fc, is_constant())
req = is_op("qnn.requantize")(
qnn_fc | bias_add, is_constant(), is_constant(), is_constant(), is_constant()
)
clip_or_req = req.optional(is_op("clip"))
return clip_or_req

def check_qnn_fully_connected(pattern):
"""Check if the fully connected is supported by CMSIS-NN."""
if str(pattern.op.name) == "clip":
relu = pattern
requantize = relu.args[0]
else:
requantize = pattern
requantize_input = requantize.args[0]
bias_add = None
bias_dtype = "int32"
if str(requantize_input.op.name) == "nn.bias_add":
bias_add = requantize_input
fc = bias_add.args[0]
bias_dtype = bias_add.args[1].checked_type.dtype
else:
fc = requantize_input
fc_input = fc.args[0]
fc_weight = fc.args[1]

# kernel zero_point should be 0
kernel_zp = fc.args[3].data.numpy().item(0)

return (
fc.attrs.out_dtype == "int32"
and fc_input.checked_type.dtype == "int8"
and fc_weight.checked_type.dtype == "int8"
and pattern.checked_type.dtype == "int8"
and bias_dtype == "int32"
and kernel_zp == 0
)

def qnn_avg_pool2d_pattern():
"""Matches average pooling with optional Relu"""
pattern = is_op("cast")(wildcard())
pattern = is_op("nn.avg_pool2d")(pattern)
pattern = is_op("cast")(pattern)
pattern = pattern.optional(is_op("clip"))
return pattern

def check_qnn_avg_pool2d(pattern):
"""Check if avg pool2d is supported by CMSIS-NN."""
in_cast = pattern
out_cast = in_cast.args[0].args[0]
return in_cast.checked_type.dtype == "int8" and out_cast.checked_type.dtype == "int32"

def qnn_max_pool2d_pattern():
"""Matches max pool2d with optional Relu"""
pattern = is_op("nn.max_pool2d")(wildcard())
pattern = pattern.optional(is_op("clip"))
return pattern

def check_qnn_max_pool2d(pattern):
"""Check if max pool2d is supported by CMSIS-NN."""
return True

def binary_op_pattern(op):
"""Matches QNN binary operation"""
return is_op(f"qnn.{op}")(
@@ -166,8 +233,11 @@ def check_qnn_binary_op(extract):
)

return [
("cmsis-nn.qnn_softmax", qnn_softmax_pattern(), check_qnn_softmax),
("cmsis-nn.qnn_conv2d", qnn_conv2d_pattern(), check_qnn_conv2d),
("cmsis-nn.qnn_fully_connected", qnn_fully_connected_pattern(), check_qnn_fully_connected),
("cmsis-nn.qnn_avg_pool2d", qnn_avg_pool2d_pattern(), check_qnn_avg_pool2d),
("cmsis-nn.qnn_max_pool2d", qnn_max_pool2d_pattern(), check_qnn_max_pool2d),
("cmsis-nn.qnn_mul", binary_op_pattern("mul"), check_qnn_binary_op),
("cmsis-nn.qnn_add", binary_op_pattern("add"), check_qnn_binary_op),
("cmsis-nn.qnn_softmax", qnn_softmax_pattern(), check_qnn_softmax),
]
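As a usage note, here is a minimal sketch of how a pattern table like the one returned above is typically consumed when partitioning a model for CMSIS-NN. The helper name partition_for_cmsisnn_sketch is invented for illustration, the import path for get_pattern_table is an assumption, and the pass ordering simply follows TVM's generic BYOC flow (MergeComposite, AnnotateTarget, PartitionGraph) rather than anything introduced by this diff:

import tvm
from tvm import relay
from tvm.relay.op.contrib.register import get_pattern_table  # assumed import path

def partition_for_cmsisnn_sketch(mod):
    """Fuse matching op sequences into composites and partition them for CMSIS-NN."""
    cmsisnn_patterns = get_pattern_table("cmsis-nn")
    seq = tvm.transform.Sequential(
        [
            relay.transform.InferType(),
            # Each (name, pattern, check) tuple above drives MergeComposite:
            # subgraphs that match the pattern and pass the check become
            # composite functions tagged with the name.
            relay.transform.MergeComposite(cmsisnn_patterns),
            relay.transform.AnnotateTarget("cmsis-nn"),
            relay.transform.PartitionGraph(),
        ]
    )
    return seq(mod)

The composite names produced this way are what relay_to_tir.cc later dispatches on when lowering each partitioned function to the corresponding CMSIS-NN call.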
4 changes: 2 additions & 2 deletions src/relay/backend/contrib/cmsisnn/generate_constants.cc
@@ -134,10 +134,10 @@ class GenerateConstantsMutator : public MixedModeMutator {
int32_t* multiplier = static_cast<int32_t*>(multiplier_nda->data);
int32_t* shift = static_cast<int32_t*>(shift_nda->data);
for (int i = 0; i < out_channels; ++i) {
double effective_output_scale =
double quantized_multiplier =
static_cast<double>(input_scales[i]) / static_cast<double>(output_scale);
std::tie(*(multiplier + i), *(shift + i)) =
tvm::relay::qnn::GetFixedPointMultiplierShift(effective_output_scale);
tvm::relay::qnn::GetFixedPointMultiplierShift(quantized_multiplier);
}

// Create constants from requantization multiplier and shift
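For reference, GetFixedPointMultiplierShift used above decomposes the per-channel scale ratio into a Q31 multiplier and a power-of-two shift so that requantization can be performed in integer arithmetic. Below is a minimal Python sketch of that decomposition, assuming the standard frexp-based scheme; the edge-case handling in TVM's own helper may differ:

import math

def fixed_point_multiplier_shift(real_multiplier):
    """Decompose real_multiplier ~= multiplier * 2**(shift - 31), multiplier in Q31."""
    if real_multiplier == 0.0:
        return 0, 0
    significand, shift = math.frexp(real_multiplier)  # significand in [0.5, 1)
    multiplier = round(significand * (1 << 31))
    if multiplier == (1 << 31):  # rounding pushed the significand up to 1.0
        multiplier //= 2
        shift += 1
    return multiplier, shift

# Example: a scale ratio of 0.75 becomes (1610612736, 0),
# since 1610612736 * 2**-31 == 0.75.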
225 changes: 204 additions & 21 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -84,17 +84,9 @@ class RelayToTIRVisitor : public MixedModeMutator {

tir::PrimFunc replacement_func(func_signature, body, VoidType(), Map<tir::Var, tir::Buffer>(),
DictAttrs(dict_attrs));

ir_module_->Add(global_var, replacement_func);
}

Array<PrimExpr> CMSISNNDimensions(const Array<PrimExpr>& shape) {
ICHECK(shape.size() == 4) << "Supports only CMSIS-NN shapes of dimension 4.";
return Array<PrimExpr>{ToArg(qnn::get_const_int(shape[0])), ToArg(qnn::get_const_int(shape[1])),
ToArg(qnn::get_const_int(shape[2])),
ToArg(qnn::get_const_int(shape[3]))};
}

void EmitConv2D(const GlobalVar& global_var, const Expr& expr) {
const CallNode* clip_call = nullptr;
const CallNode* requantize_call = nullptr;
@@ -164,21 +156,15 @@ class RelayToTIRVisitor : public MixedModeMutator {
ToArg(dilation_w), ToArg(dilation_h), ToArg(clip_min),
ToArg(clip_max)};

// cmsis_nn_dims *input_dims (NHWC)
// layout NHWC
Contributor: Sorry Ashutosh, I still don't follow what this comment means. How is it related to the following line?

Contributor Author: This was to help understand that the following shape is in NHWC layout, which will eventually be passed on to the CMSIS-NN API.

Contributor: I think it might be worth saying that explicitly :)

Contributor Author: Ack

Array<PrimExpr> input_shape = conv2d_call->args[0]->type_as<TensorTypeNode>()->shape;
Array<PrimExpr> input_dims = CMSISNNDimensions(input_shape);

// cmsis_nn_dims *filter_dims (OHWI for Conv2D and IHWO for depthwise)
// OHWI for Conv2D and IHWO for depthwise
Contributor: Same here, where is the OHWI information used?

Contributor Author: Same as above. Maybe I should explicitly say that this is for the CMSIS-NN API?

Contributor: I think the information is that the following line produces a tensor in that layout and subsequent accesses assume that?

Contributor Author: Yes, the comment got displaced in the process of review and new edits 😆

Array<PrimExpr> filter_shape = conv2d_call->args[1]->type_as<TensorTypeNode>()->shape;
Array<PrimExpr> filter_dims = CMSISNNDimensions(filter_shape);

// cmsis_nn_dims *bias_dims
Array<PrimExpr> bias_shape{1, 1, 1, out_channels};
Array<PrimExpr> bias_dims = CMSISNNDimensions(bias_shape);

// cmsis_nn_dims *output_dims (same order as input_dims)
Array<PrimExpr> output_shape = conv2d_call->type_as<TensorTypeNode>()->shape;
Array<PrimExpr> output_dims = CMSISNNDimensions(output_shape);

int32_t depth_multiplier = -1;
int kernel_pos_o = kernel_layout.find("O");
@@ -194,7 +180,7 @@ class RelayToTIRVisitor : public MixedModeMutator {
if (depth_multiplier != -1) {
cmsisnn_api = "arm_depthwise_conv_wrapper_s8";
Array<PrimExpr> depthwise_filter_shape{1, filter_shape[0], filter_shape[1], out_channels};
Contributor: Is it possible to use something like
int kernel_pos_o = kernel_layout.find("O");
instead of the numbers?

Contributor Author: Yes, partly done above. Missed here.

filter_dims = CMSISNNDimensions(depthwise_filter_shape);
filter_shape = depthwise_filter_shape;
}

tvm::Array<PrimExpr> call_ext_args = {tir::StringImm(cmsisnn_api), input, filter, multiplier};
@@ -216,10 +202,10 @@ class RelayToTIRVisitor : public MixedModeMutator {
ToArg(context_buffer_size)};

scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args);
scalar_args = tvm::runtime::Concat(scalar_args, input_dims);
scalar_args = tvm::runtime::Concat(scalar_args, filter_dims);
scalar_args = tvm::runtime::Concat(scalar_args, bias_dims);
scalar_args = tvm::runtime::Concat(scalar_args, output_dims);
scalar_args = tvm::runtime::Concat(scalar_args, input_shape);
scalar_args = tvm::runtime::Concat(scalar_args, filter_shape);
scalar_args = tvm::runtime::Concat(scalar_args, bias_shape);
scalar_args = tvm::runtime::Concat(scalar_args, output_shape);
call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args);

Array<tir::Var> func_signature{input, filter, multiplier, filter_scale};
@@ -234,6 +220,197 @@ class RelayToTIRVisitor : public MixedModeMutator {
context_buffer_size);
}

void EmitFullyConnected(const GlobalVar& global_var, const Expr& expr) {
const CallNode* clip_call = nullptr;
const CallNode* requantize_call = nullptr;
const CallNode* bias_add_call = nullptr;
const CallNode* fc_call = nullptr;
const CallNode* final_call = expr.as<CallNode>();
const OpNode* final_op = final_call->op.as<OpNode>();
if (final_op->name == "clip") {
clip_call = final_call;
requantize_call = clip_call->args[0].as<CallNode>();
} else {
requantize_call = final_call;
}
const CallNode* requantize_input = requantize_call->args[0].as<CallNode>();
const OpNode* requantize_input_op = requantize_input->op.as<OpNode>();
if (requantize_input_op->name == "nn.bias_add") {
bias_add_call = requantize_input;
fc_call = bias_add_call->args[0].as<CallNode>();
} else {
fc_call = requantize_input;
}

// TIR variables are created in the order they appear in the Relay partitioned function
// %1 = qnn.dense(%input, %weight_const_0, input_zero_point_scalar, kernel_zero_point_scalar,
// %input_scale_scalar, %kernel_scale_scalar)
// %2 = nn.bias_add(%1, %bias_const_1, axis=1)
// %3 = qnn.requantize(%2, %req_input_scale_scalar, %req_input_zero_point_scalar,
// %output_scale_scalar, %output_zero_point_scalar)
// clip(%3, a_min=%min_scalar, a_max=%max_scalar)
tir::Var input("input", DataType::Handle(8));
tir::Var filter("filter", DataType::Handle(8));
tir::Var bias("bias", DataType::Handle(32));
tir::Var output("output", DataType::Handle(8));

// Individual fields of the CMSIS-NN API's struct arguments are passed to call_extern
// https://github.com/ARM-software/CMSIS_5/blob/def6f800f95661eb3451d317f7d0dde504f6020d/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L50

// prepare cmsis_nn_fc_params
const DenseAttrs* dense_attrs = fc_call->attrs.as<DenseAttrs>();
int32_t input_offset = -GetScalarFromConstant<int32_t>(fc_call->args[2]);
int32_t filter_offset = -GetScalarFromConstant<int32_t>(fc_call->args[3]);
int32_t output_offset = GetScalarFromConstant<int32_t>(requantize_call->args[4]);
float input_scale = GetScalarFromConstant<float>(requantize_call->args[1]);
float output_scale = GetScalarFromConstant<float>(requantize_call->args[3]);
int32_t out_channels = qnn::get_const_int(dense_attrs->units);
int32_t clip_min, clip_max;
if (clip_call) {
const ClipAttrs* clip_attrs = clip_call->attrs.as<ClipAttrs>();
clip_min = clip_attrs->a_min;
clip_max = clip_attrs->a_max;
} else {
clip_min = -128;
clip_max = 127;
}

double quantized_multiplier =
static_cast<double>(input_scale) / static_cast<double>(output_scale);
auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(quantized_multiplier);
int32_t multiplier = std::get<0>(mult_shift_pair);
int32_t shift = std::get<1>(mult_shift_pair);

tvm::Array<PrimExpr> scalar_args = {
ToArg(input_offset), ToArg(filter_offset), ToArg(output_offset), ToArg(clip_min),
ToArg(clip_max), ToArg(multiplier), ToArg(shift)};

Array<PrimExpr> input_shape = fc_call->args[0]->type_as<TensorTypeNode>()->shape;
int32_t batch_size = qnn::get_const_int(input_shape[0]);
int32_t in_channels = qnn::get_const_int(input_shape[1]);
Array<PrimExpr> cmsisnn_input_shape{input_shape[0], 1, 1, input_shape[1]};

Array<PrimExpr> cmsisnn_filter_shape{in_channels, 1, 1, out_channels};

Array<PrimExpr> bias_shape{1, 1, 1, out_channels};

Array<PrimExpr> cmsisnn_output_shape{batch_size, 1, 1, out_channels};

tvm::Array<PrimExpr> call_ext_args = {tir::StringImm("arm_fully_connected_s8"), input, filter};
if (bias_add_call) {
call_ext_args.push_back(bias);
}
call_ext_args.push_back(output);

int context_buffer_size = 0;
std::string context_buffer_name = "NULL";
tvm::Array<PrimExpr> context_buffer_args = {tir::StringImm(context_buffer_name),
ToArg(context_buffer_size)};

scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args);
Contributor: I would suggest keeping the declaration/initialization of the vectors that get concatenated closer to the concat.

scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_input_shape);
Contributor: [Clarity] I think we can bring the declaration of "cmsisnn_input_shape" closer to here, as it is not used before this.

Contributor Author: We cannot. It's being used to derive another size in the middle. Also, I would prefer to keep the concatenation of arguments in one place. Locating all arguments in the same place has helped me with debugging.

Contributor: I think I'm missing something here :). Isn't it defined at line 291, and is there a use in between? Anyway, if it helps debugging, let's have it this way.

Contributor Author: Sorry, I was looking at a different operator. For some of the others, context buffer creation requires this shape. Just to be consistent, I didn't want to change the code layout.

scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_filter_shape);
Contributor: [Clarity] I think we can bring the declaration of "cmsisnn_filter_shape" closer to here, as it is not used before this.

Contributor Author: Same as above.

scalar_args = tvm::runtime::Concat(scalar_args, bias_shape);
Contributor: [Clarity] I think we can bring the declaration of "bias_shape" closer to here, as it is not used before this.

Contributor Author: Same as above.

scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_output_shape);
Contributor: [Clarity] I think we can bring the declaration of "cmsisnn_output_shape" closer to here, as it is not used before this.

Contributor Author: Same as above.

call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args);

Array<tir::Var> func_signature{input, filter};
if (bias_add_call) {
func_signature.push_back(bias);
}
func_signature.push_back(output);
CreatePrimFuncForExtern(global_var, func_signature, call_ext_args, context_buffer_name,
context_buffer_size);
}
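To summarise the shape bookkeeping in EmitFullyConnected above: the 2-D qnn.dense tensors are padded out to the 4-D dims that arm_fully_connected_s8 expects. A small illustrative sketch of that mapping (the helper name is invented for illustration):

def cmsisnn_fc_dims(batch_size, in_channels, out_channels):
    """Shape remapping performed above for arm_fully_connected_s8 (illustrative only)."""
    return {
        "input":  (batch_size, 1, 1, in_channels),
        "filter": (in_channels, 1, 1, out_channels),
        "bias":   (1, 1, 1, out_channels),
        "output": (batch_size, 1, 1, out_channels),
    }

# e.g. a qnn.dense with input [1, 64] and 10 units maps to
# input (1, 1, 1, 64), filter (64, 1, 1, 10), bias (1, 1, 1, 10), output (1, 1, 1, 10).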

void EmitPool2D(const GlobalVar& global_var, const Expr& expr, const String pool_name) {
Call clip, pool;
Call final_call = GetRef<Call>(expr.as<CallNode>());
Op final_op = GetRef<Op>(final_call->op.as<OpNode>());
if (final_op->name == "clip") {
clip = final_call;
Call clip_input = GetRef<Call>(clip->args[0].as<CallNode>());
Op clip_input_op = GetRef<Op>(clip_input->op.as<OpNode>());
if (clip_input_op->name == "cast") {
pool = GetRef<Call>(clip_input->args[0].as<CallNode>());
} else { // max_pool2d
pool = clip_input;
}
} else if (final_op->name == "cast") {
pool = GetRef<Call>(final_call->args[0].as<CallNode>());
} else { // max_pool2d
pool = final_call;
}

// prepare cmsis_nn_pool_params
int32_t stride_h, stride_w, padding_h, padding_w, pool_size_h, pool_size_w;
int32_t clip_min, clip_max;
std::string cmsisnn_api;
if (pool_name == "cmsis-nn.qnn_avg_pool2d") {
cmsisnn_api = "arm_avgpool_s8";
const AvgPool2DAttrs* attrs = pool->attrs.as<AvgPool2DAttrs>();
stride_h = qnn::get_const_int(attrs->strides[0]);
stride_w = qnn::get_const_int(attrs->strides[1]);
padding_h = qnn::get_const_int(attrs->padding[0]);
padding_w = qnn::get_const_int(attrs->padding[1]);
pool_size_h = qnn::get_const_int(attrs->pool_size[0]);
pool_size_w = qnn::get_const_int(attrs->pool_size[1]);
} else {
cmsisnn_api = "arm_max_pool_s8";
const MaxPool2DAttrs* attrs = pool->attrs.as<MaxPool2DAttrs>();
stride_h = qnn::get_const_int(attrs->strides[0]);
stride_w = qnn::get_const_int(attrs->strides[1]);
padding_h = qnn::get_const_int(attrs->padding[0]);
padding_w = qnn::get_const_int(attrs->padding[1]);
pool_size_h = qnn::get_const_int(attrs->pool_size[0]);
pool_size_w = qnn::get_const_int(attrs->pool_size[1]);
}
if (clip.defined()) {
const ClipAttrs* clip_attrs = clip->attrs.as<ClipAttrs>();
clip_min = clip_attrs->a_min;
clip_max = clip_attrs->a_max;
} else {
clip_min = -128;
clip_max = 127;
}

tvm::Array<PrimExpr> scalar_args = {ToArg(stride_h), ToArg(stride_w), ToArg(padding_h),
ToArg(padding_w), ToArg(clip_min), ToArg(clip_max)};

Array<PrimExpr> input_shape = pool->args[0]->type_as<TensorTypeNode>()->shape;
Array<PrimExpr> cmsisnn_input_shape{1, input_shape[1], input_shape[2], input_shape[3]};

Array<PrimExpr> cmsisnn_filter_shape{1, pool_size_h, pool_size_w, 1};

Array<PrimExpr> output_shape = pool->type_as<TensorTypeNode>()->shape;
Array<PrimExpr> cmsisnn_output_shape{1, output_shape[1], output_shape[2], output_shape[3]};

tir::Var input("input", DataType::Handle(8));
tir::Var output("output", DataType::Handle(8));
tvm::Array<PrimExpr> call_ext_args = {tir::StringImm(cmsisnn_api), input, output};

int context_buffer_size = 0;
std::string context_buffer_name = "NULL";
if (pool_name == "cmsisnn.qnn_avg_pool2d") {
// TODO(@Mousius): Need to move this into buffer_size calculations
context_buffer_size = qnn::get_const_int(input_shape[3]) * sizeof(int32_t);
context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++);
}
tvm::Array<PrimExpr> context_buffer_args = {tir::StringImm(context_buffer_name),
ToArg(context_buffer_size)};

scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args);
scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_input_shape);
scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_filter_shape);
scalar_args = tvm::runtime::Concat(scalar_args, cmsisnn_output_shape);
call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args);

Array<tir::Var> func_signature{input, output};

CreatePrimFuncForExtern(global_var, func_signature, call_ext_args, context_buffer_name,
context_buffer_size);
}

void EmitSoftMax(const GlobalVar& global_var, const Expr& expr) {
const CallNode* quantize_call = expr.as<CallNode>();
const CallNode* softmax_call = quantize_call->args[0].as<CallNode>();
@@ -422,6 +599,12 @@ class RelayToTIRVisitor : public MixedModeMutator {
if (comp_name == "cmsis-nn.qnn_conv2d") {
EmitConv2D(new_global_var, composite_func->body);
}
if (comp_name == "cmsis-nn.qnn_fully_connected") {
EmitFullyConnected(new_global_var, composite_func->body);
}
if (comp_name == "cmsis-nn.qnn_avg_pool2d" || comp_name == "cmsis-nn.qnn_max_pool2d") {
EmitPool2D(new_global_var, composite_func->body, comp_name.value());
}

Array<Expr> args;
for (const auto& arg : call->args) {