Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 36 additions & 11 deletions python/tvm/relay/op/contrib/cmsisnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,20 +121,14 @@ def check_qnn_conv2d(pattern):
requantize = pattern
requantize_input = requantize.args[0]
bias_add = None
bias_dtype = "int32"
if str(requantize_input.op.name) == "nn.bias_add":
bias_add = requantize_input
conv2d = bias_add.args[0]
bias_dtype = bias_add.args[1].checked_type.dtype
else:
conv2d = requantize_input
conv2d_input = conv2d.args[0]
conv2d_weight = conv2d.args[1]

# kernel zero_point should be 0
kernel_zp = conv2d.args[3].data.numpy()
kernel_zp = [kernel_zp] if kernel_zp.ndim == 0 else kernel_zp

# check if depthwise Conv2D
kernel_layout = conv2d.attrs.kernel_layout
pos_o = kernel_layout.index("O")
Expand All @@ -145,12 +139,43 @@ def check_qnn_conv2d(pattern):
):
is_depthwise = True

# check if dtypes are supported for the following entities
# (input_dtype, weight_dtype, bias_dtype, out_dtype, pattern_dtype)
are_dtypes_valid = False
conv2d_input_dtype = conv2d_input.checked_type.dtype
if bias_add:
bias_dtype = bias_add.args[1].checked_type.dtype
else:
# this is only to enable the following check that validates all sorts of dtypes
bias_dtype = "int32" if conv2d_input_dtype == "int8" else "int64"
valid_dtypes = None
if conv2d_input_dtype == "int8":
valid_dtypes = ("int8", "int8", "int32", "int32", "int8")
elif conv2d_input_dtype == "int16":
valid_dtypes = ("int16", "int8", "int64", "int64", "int16")

if (
conv2d_input_dtype,
conv2d_weight.checked_type.dtype,
bias_dtype,
conv2d.attrs.out_dtype,
pattern.checked_type.dtype,
) == valid_dtypes:
are_dtypes_valid = True

# input_zero_point should be 0 when int16
valid_input_zp = True
if conv2d_input_dtype == "int16" and conv2d.args[2].data.numpy().item(0) != 0:
valid_input_zp = False

# kernel zero_point should be 0
kernel_zp = conv2d.args[3].data.numpy()
kernel_zp = [kernel_zp] if kernel_zp.ndim == 0 else kernel_zp

# combination of all checks to decide if pattern is eligible for partitioning
ret = (
conv2d.attrs.out_dtype == "int32"
and conv2d_input.checked_type.dtype == "int8"
and conv2d_weight.checked_type.dtype == "int8"
and pattern.checked_type.dtype == "int8"
and bias_dtype == "int32"
are_dtypes_valid
and valid_input_zp
and all([zp == 0 for zp in kernel_zp])
and (not is_depthwise or bias_add is not None)
)
Expand Down
80 changes: 73 additions & 7 deletions src/relay/backend/contrib/cmsisnn/buffer_size.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,27 @@ namespace relay {
namespace contrib {
namespace cmsisnn {

int Conv2dBufferSize(Target target, int32_t padding_w, int32_t padding_h, int32_t input_n,
int32_t input_h, int32_t input_c, int32_t output_h, int32_t output_w,
int32_t stride_w, int32_t stride_h, int32_t dilation_w, int32_t dilation_h,
int32_t filter_w, int32_t filter_h) {
int Conv2dBufferSize(bool is_int16, Target target, int32_t padding_w, int32_t padding_h,
                     int32_t input_n, int32_t input_h, int32_t input_c, int32_t output_h,
                     int32_t output_w, int32_t stride_w, int32_t stride_h, int32_t dilation_w,
                     int32_t dilation_h, int32_t filter_w, int32_t filter_h) {
  // Dispatch to the dtype-specific buffer size calculation based on the
  // element type of the convolution (int16 vs int8).
  return is_int16 ? Conv2dBufferSizeInt16(target, padding_w, padding_h, input_n, input_h, input_c,
                                          output_h, output_w, stride_w, stride_h, dilation_w,
                                          dilation_h, filter_w, filter_h)
                  : Conv2dBufferSizeInt8(target, padding_w, padding_h, input_n, input_h, input_c,
                                         output_h, output_w, stride_w, stride_h, dilation_w,
                                         dilation_h, filter_w, filter_h);
}

int Conv2dBufferSizeInt8(Target target, int32_t padding_w, int32_t padding_h, int32_t input_n,
int32_t input_h, int32_t input_c, int32_t output_h, int32_t output_w,
int32_t stride_w, int32_t stride_h, int32_t dilation_w, int32_t dilation_h,
int32_t filter_w, int32_t filter_h) {
bool is1x1 = (padding_w == 0) && (padding_h == 0) && (input_c % 4 == 0) && (stride_w == 1) &&
(stride_h == 1) && (filter_w == 1) && (filter_h == 1) && (dilation_w == 1) &&
(dilation_h == 1);
Expand Down Expand Up @@ -62,9 +79,38 @@ int Conv2dBufferSize(Target target, int32_t padding_w, int32_t padding_h, int32_
return 0;
}

int DepthwiseConv2dBufferSize(Target target, int32_t input_n, int32_t input_c, int32_t output_c,
int32_t filter_w, int32_t filter_h, int32_t dilation_w,
int32_t dilation_h) {
/*!
 * \brief Scratch buffer size (in bytes) required by the int16 CMSIS-NN convolution
 * wrapper (arm_convolve_wrapper_s16) for the given parameters and target features.
 *
 * A non-zero buffer is only required on DSP-only targets (DSP extension present,
 * MVE absent) when the fast kernel path applies: filter volume below 512 elements
 * and no dilation. NOTE(review): thresholds mirror the CMSIS-NN wrapper's internal
 * checks — keep in sync with the CMSIS-NN version in use.
 */
int Conv2dBufferSizeInt16(Target target, int32_t padding_w, int32_t padding_h, int32_t input_n,
                          int32_t input_h, int32_t input_c, int32_t output_h, int32_t output_w,
                          int32_t stride_w, int32_t stride_h, int32_t dilation_w,
                          int32_t dilation_h, int32_t filter_w, int32_t filter_h) {
  bool has_mve = target->GetFeature<Bool>("has_mve").value_or(Bool(false));
  bool has_dsp = target->GetFeature<Bool>("has_dsp").value_or(Bool(false));

  if (has_dsp && !has_mve) {
    // Fast path: small filter volume, no dilation -> im2col-style buffer of
    // two filter-sized slices of int16 elements.
    if ((filter_w * filter_h * input_c < 512) && dilation_w == 1 && dilation_h == 1) {
      return (2 * input_c * filter_w * filter_h) * static_cast<int32_t>(sizeof(int16_t));
    }
  }
  return 0;
}

int DepthwiseConv2dBufferSize(bool is_int16, Target target, int32_t input_n, int32_t input_c,
                              int32_t output_c, int32_t filter_w, int32_t filter_h,
                              int32_t dilation_w, int32_t dilation_h, int32_t depth_multiplier) {
  // Dispatch to the dtype-specific buffer size calculation based on the
  // element type of the depthwise convolution (int16 vs int8).
  return is_int16 ? DepthwiseConv2dBufferSizeInt16(target, input_n, input_c, output_c, filter_w,
                                                   filter_h, dilation_w, dilation_h,
                                                   depth_multiplier)
                  : DepthwiseConv2dBufferSizeInt8(target, input_n, input_c, output_c, filter_w,
                                                  filter_h, dilation_w, dilation_h,
                                                  depth_multiplier);
}

int DepthwiseConv2dBufferSizeInt8(Target target, int32_t input_n, int32_t input_c, int32_t output_c,
int32_t filter_w, int32_t filter_h, int32_t dilation_w,
int32_t dilation_h, int32_t depth_multiplier) {
bool has_mve = target->GetFeature<Bool>("has_mve").value_or(Bool(false));
bool has_dsp = target->GetFeature<Bool>("has_dsp").value_or(Bool(false));

Expand All @@ -78,6 +124,26 @@ int DepthwiseConv2dBufferSize(Target target, int32_t input_n, int32_t input_c, i
return 0;
}

/*!
 * \brief Scratch buffer size (in bytes) required by the int16 CMSIS-NN depthwise
 * convolution wrapper (arm_depthwise_conv_wrapper_s16) for the given parameters
 * and target features.
 *
 * A non-zero buffer is only needed when the specialised kernel applies:
 * depth_multiplier == 1, no dilation, and filter volume below 512 elements,
 * on a target with the DSP extension. MVE targets need a larger, 8-byte padded
 * buffer. NOTE(review): thresholds mirror the CMSIS-NN wrapper's internal
 * checks — keep in sync with the CMSIS-NN version in use.
 */
int DepthwiseConv2dBufferSizeInt16(Target target, int32_t input_n, int32_t input_c,
                                   int32_t output_c, int32_t filter_w, int32_t filter_h,
                                   int32_t dilation_w, int32_t dilation_h,
                                   int32_t depth_multiplier) {
  bool has_mve = target->GetFeature<Bool>("has_mve").value_or(Bool(false));
  bool has_dsp = target->GetFeature<Bool>("has_dsp").value_or(Bool(false));

  if (depth_multiplier == 1 && dilation_w == 1 && dilation_h == 1 &&
      filter_w * filter_h * input_c < 512) {
    if (has_dsp) {
      if (has_mve) {
        // MVE path needs four filter-sized int16 slices plus 8 bytes of padding.
        return 4 * input_c * filter_w * filter_h * static_cast<int32_t>(sizeof(int16_t)) + 8;
      } else {
        // Scalar DSP path needs one filter-sized int16 slice.
        return input_c * filter_w * filter_h * static_cast<int32_t>(sizeof(int16_t));
      }
    }
  }
  return 0;
}

int AvgPoolBufferSize(Target target, int32_t input_c) {
bool has_mve = target->GetFeature<Bool>("has_mve").value_or(Bool(false));
bool has_dsp = target->GetFeature<Bool>("has_dsp").value_or(Bool(false));
Expand Down
36 changes: 29 additions & 7 deletions src/relay/backend/contrib/cmsisnn/buffer_size.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ namespace cmsisnn {
* See:
* https://github.com/ARM-software/CMSIS_5/blob/8c60448c0e1e50e426180b26db9bc31ddf774361/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L108-L127
*
 * \param is_int16 - true when the convolution operates on int16 data, false for int8
* \param target - CMSIS-NN Target
* \param padding_w - Width padding
* \param padding_h - Height padding
Expand All @@ -56,16 +57,27 @@ namespace cmsisnn {
*
* \return Size of buffer to allocate for convolution
*/
int Conv2dBufferSize(Target target, int32_t padding_w, int32_t padding_h, int32_t input_n,
int32_t input_h, int32_t input_c, int32_t output_h, int32_t output_w,
int32_t stride_w, int32_t stride_h, int32_t dilation_w, int32_t dilation_h,
int32_t filter_w, int32_t filter_h);
int Conv2dBufferSize(bool is_int16, Target target, int32_t padding_w, int32_t padding_h,
int32_t input_n, int32_t input_h, int32_t input_c, int32_t output_h,
int32_t output_w, int32_t stride_w, int32_t stride_h, int32_t dilation_w,
int32_t dilation_h, int32_t filter_w, int32_t filter_h);

int Conv2dBufferSizeInt8(Target target, int32_t padding_w, int32_t padding_h, int32_t input_n,
int32_t input_h, int32_t input_c, int32_t output_h, int32_t output_w,
int32_t stride_w, int32_t stride_h, int32_t dilation_w, int32_t dilation_h,
int32_t filter_w, int32_t filter_h);

int Conv2dBufferSizeInt16(Target target, int32_t padding_w, int32_t padding_h, int32_t input_n,
int32_t input_h, int32_t input_c, int32_t output_h, int32_t output_w,
int32_t stride_w, int32_t stride_h, int32_t dilation_w,
int32_t dilation_h, int32_t filter_w, int32_t filter_h);

/*!
* \brief Calculates the appropriate buffer size for CMSIS-NN Depthwise Convolutions
* See:
* https://github.com/ARM-software/CMSIS_5/blob/325443e52637b6c7eedbd160d238a6c462e89c9f/CMSIS/NN/Source/ConvolutionFunctions/arm_depthwise_conv_wrapper_s8.c#L115-L129
*
 * \param is_int16 - true when the depthwise convolution operates on int16 data, false for int8
* \param target - CMSIS-NN Target
* \param input_n - Input batch size
* \param input_c - Input channels
Expand All @@ -74,12 +86,22 @@ int Conv2dBufferSize(Target target, int32_t padding_w, int32_t padding_h, int32_
* \param filter_h - Filter height
* \param dilation_w - Dilation width
* \param dilation_h - Dilation height
* \param depth_multiplier - Depth Multiplier for Depthwise Convolution
*
* \return Size of buffer to allocate for depthwise convolution
*/
int DepthwiseConv2dBufferSize(Target target, int32_t input_n, int32_t input_c, int32_t output_c,
int32_t filter_w, int32_t filter_h, int32_t dilation_w,
int32_t dilation_h);
int DepthwiseConv2dBufferSize(bool is_int16, Target target, int32_t input_n, int32_t input_c,
int32_t output_c, int32_t filter_w, int32_t filter_h,
int32_t dilation_w, int32_t dilation_h, int32_t depth_multiplier);

int DepthwiseConv2dBufferSizeInt8(Target target, int32_t input_n, int32_t input_c, int32_t output_c,
int32_t filter_w, int32_t filter_h, int32_t dilation_w,
int32_t dilation_h, int32_t depth_multiplier);

int DepthwiseConv2dBufferSizeInt16(Target target, int32_t input_n, int32_t input_c,
int32_t output_c, int32_t filter_w, int32_t filter_h,
int32_t dilation_w, int32_t dilation_h,
int32_t depth_multiplier);

/*!
* \brief Calculates the appropriate buffer size for CMSIS-NN Average Pooling
Expand Down
52 changes: 35 additions & 17 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,17 @@ class RelayToTIRVisitor : public MixedModeMutator {
const Map<tir::Var, tir::Buffer>& buffer_map,
tvm::Array<PrimExpr> call_extern_args,
PrimExpr context_buffer_var = PrimExpr(),
int context_buffer_size = 0) {
int context_buffer_size = 0, int num_bits = 8) {
Map<String, ObjectRef> dict_attrs;
dict_attrs.Set(tvm::attr::kGlobalSymbol, global_var->name_hint);
dict_attrs.Set(tvm::attr::kTarget, target_);
dict_attrs.Set("tir.noalias", Bool(true));

tir::Stmt body = tir::Evaluate(
tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), call_extern_args));
tvm::tir::Call(DataType::Int(num_bits), tir::builtin::call_extern(), call_extern_args));

if (context_buffer_size) {
body = tir::Allocate(Downcast<tir::Var>(context_buffer_var), DataType::Int(8),
body = tir::Allocate(Downcast<tir::Var>(context_buffer_var), DataType::Int(num_bits),
{context_buffer_size}, tir::const_true(), body);
}

Expand Down Expand Up @@ -133,6 +133,22 @@ class RelayToTIRVisitor : public MixedModeMutator {
} else {
conv2d_call = requantize_input;
}
int32_t dtype_bits = conv2d_call->args[0]->type_as<TensorTypeNode>()->dtype.bits();

// Determine bitwidth of buffers based on input dtype
int32_t input_bits = 8;
int32_t filter_bits = 8;
int32_t bias_bits = 32;
int32_t output_bits = 8;
int32_t context_buffer_bits = 8;
bool is_int16 = false;
if (dtype_bits == 16) {
is_int16 = true;
input_bits = 16;
bias_bits = 64;
output_bits = 16;
context_buffer_bits = 16;
}

// TIR variables are created in the order they appear in the Relay partitioned function
// %1 = qnn.conv2d(%input, %weight_const_0, input_zero_point_scalar,
Expand All @@ -145,14 +161,14 @@ class RelayToTIRVisitor : public MixedModeMutator {
const int filter_scale_pos = 3;
const int input_scale_pos = bias_add_call ? 5 : 4;
BufferCreator buffer_creator;
tir::Var input = buffer_creator.CreateBufferVar("input", DataType::Handle(8));
tir::Var filter = buffer_creator.CreateBufferVar("filter", DataType::Handle(8));
tir::Var input = buffer_creator.CreateBufferVar("input", DataType::Handle(input_bits));
tir::Var filter = buffer_creator.CreateBufferVar("filter", DataType::Handle(filter_bits));
tir::Var multiplier = buffer_creator.CreateBufferVar("multiplier", DataType::Handle(32));
if (bias_add_call) {
buffer_creator.CreateBufferVar("bias", DataType::Handle(32));
buffer_creator.CreateBufferVar("bias", DataType::Handle(bias_bits));
}
tir::Var shift = buffer_creator.CreateBufferVar("shift", DataType::Handle(32));
tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8));
tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(output_bits));

// Relay function contains input_scale and filter_scale as function parameters at the following
// locations in the global partitioned function for Conv2D
Expand Down Expand Up @@ -217,10 +233,10 @@ class RelayToTIRVisitor : public MixedModeMutator {
scalar_args.push_back(ToArg(depth_multiplier));

// original filter_layout for depthwise is HWOI
std::string cmsisnn_api = "arm_convolve_wrapper_s8";
std::string cmsisnn_api = is_int16 ? "arm_convolve_wrapper_s16" : "arm_convolve_wrapper_s8";
bool is_depthwise = depth_multiplier != -1;
if (is_depthwise) {
cmsisnn_api = "arm_depthwise_conv_wrapper_s8";
cmsisnn_api = is_int16 ? "arm_depthwise_conv_wrapper_s16" : "arm_depthwise_conv_wrapper_s8";
int filter_pos_h = kernel_layout.find("H");
int filter_pos_w = kernel_layout.find("W");
Array<PrimExpr> depthwise_filter_shape{1, filter_shape[filter_pos_h],
Expand All @@ -242,18 +258,20 @@ class RelayToTIRVisitor : public MixedModeMutator {
Target target = CreateTarget(transform::PassContext::Current());
size_t context_buffer_size;
if (is_depthwise) {
context_buffer_size = DepthwiseConv2dBufferSize(target, input_n, input_c, output_c, filter_w,
filter_h, dilation_w, dilation_h);
context_buffer_size =
DepthwiseConv2dBufferSize(is_int16, target, input_n, input_c, output_c, filter_w,
filter_h, dilation_w, dilation_h, depth_multiplier);
} else {
context_buffer_size = Conv2dBufferSize(target, padding_w, padding_h, input_n, input_h,
input_c, output_h, output_w, stride_w, stride_h,
dilation_w, dilation_h, filter_w, filter_h);
context_buffer_size = Conv2dBufferSize(is_int16, target, padding_w, padding_h, input_n,
input_h, input_c, output_h, output_w, stride_w,
stride_h, dilation_w, dilation_h, filter_w, filter_h);
}

if (context_buffer_size) {
String context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++);
context_buffer_var = tir::Var(context_buffer_name,
PointerType(PrimType(DataType::Int(8)), "global.workspace"));
context_buffer_var =
tir::Var(context_buffer_name,
PointerType(PrimType(DataType::Int(context_buffer_bits)), "global.workspace"));
}
tvm::Array<PrimExpr> context_buffer_args = {context_buffer_var, ToArg(context_buffer_size)};

Expand All @@ -266,7 +284,7 @@ class RelayToTIRVisitor : public MixedModeMutator {

CreatePrimFuncForExtern(global_var, buffer_creator.GetPrimFuncParams(),
buffer_creator.GetBufferMap(), call_ext_args, context_buffer_var,
context_buffer_size);
context_buffer_size, context_buffer_bits);
}

void EmitFullyConnected(const GlobalVar& global_var, const Expr& expr) {
Expand Down
4 changes: 3 additions & 1 deletion src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
cmsis_func_name == "arm_elementwise_add_s8") {
CodeGenC::VisitExpr_(op, os);
} else if (cmsis_func_name == "arm_convolve_wrapper_s8" ||
cmsis_func_name == "arm_depthwise_conv_wrapper_s8") {
cmsis_func_name == "arm_convolve_wrapper_s16" ||
cmsis_func_name == "arm_depthwise_conv_wrapper_s8" ||
cmsis_func_name == "arm_depthwise_conv_wrapper_s16") {
EmitConv2D(op);
} else if (cmsis_func_name == "arm_fully_connected_s8") {
EmitFullyConnected(op);
Expand Down
Loading