12 changes: 9 additions & 3 deletions python/tvm/micro/model_library_format.py
@@ -174,11 +174,14 @@ def _build_function_memory_map(function_metadata):
device_max_workspace = dict()
main_func_metadata = function_metadata[MAIN_FUNC_NAME_STR]
num_targets = len(main_func_metadata.workspace_sizes.items())
from tvm.driver import tvmc # pylint: disable=import-outside-toplevel

external_codegens = tvmc.composite_target.get_codegen_names()
func_entries = []
target_local_entries = dict()
for i in range(num_targets):
target = main_func_metadata.workspace_sizes.items()[i][0]
device_max_workspace[target] = 0
main_target = main_func_metadata.workspace_sizes.items()[i][0]
device_max_workspace[main_target] = 0
for func_name, finfo in function_metadata.items():
if func_name == MAIN_FUNC_NAME_STR:
continue
@@ -201,8 +204,11 @@ def _build_function_memory_map(function_metadata):
"workspace_size_bytes": int(workspace_size),
}
target_local_entries[func_name].append(target_entry)
if workspace_size > device_max_workspace[target]:
if workspace_size > device_max_workspace.get(target, 0):
device_max_workspace[target] = workspace_size
# TODO(Mousius) - Remove this massive hack when Targets are unified
if target.kind.name in external_codegens:
device_max_workspace[main_target] += int(workspace_size)

for func_name, target_entries_ in target_local_entries.items():
func_entry = {
83 changes: 63 additions & 20 deletions python/tvm/relay/op/contrib/cmsisnn.py
@@ -24,6 +24,8 @@
from ...dataflow_pattern import is_constant, is_op, wildcard
from .register import register_pattern_table

tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__)


def enabled():
return "cmsis-nn" in Target.list_kinds()
@@ -53,37 +55,85 @@ def partition_for_cmsisnn(mod, params=None, **opts):
transform.InferType(),
transform.MergeComposite(pattern_table()),
transform.AnnotateTarget("cmsis-nn"),
transform.MergeCompilerRegions(),
transform.PartitionGraph(),
GenerateCMSISNNConstants(),
ExtractConstantsFromPartitionedFunction(),
transform.InferType(),
]
)

return seq(mod)


@register_pattern_table("cmsis-nn")
def pattern_table():
"""Get the CMSIS-NN compiler pattern table."""

def softmax_pattern():
def qnn_softmax_pattern():
"""Create pattern for quantized softmax"""
pattern = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
pattern = is_op("nn.softmax")(pattern)
pattern = is_op("qnn.quantize")(pattern, is_constant(), is_constant())
return pattern

def check_quantized_softmax(extract):
def check_qnn_softmax(pattern):
"""Check if softmax is supported by CMSIS-NN."""
dequantize_call = extract.args[0].args[0]
scale = extract.args[1].data.numpy().item(0)
zero_point = extract.args[2].data.numpy().item(0)
dequantize_call = pattern.args[0].args[0]
scale = pattern.args[1].data.numpy().item(0)
zero_point = pattern.args[2].data.numpy().item(0)

# check for dtypes of quantize and dequantize
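# CMSIS-NN's int8 softmax expects a fixed output quantization of scale = 1/256 and
# zero_point = -128, i.e. the full int8 range mapped onto softmax outputs in [0, 1)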
return (
(scale == 1.0 / 256 and zero_point == -128)
and extract.attrs.out_dtype == "int8"
and pattern.attrs.out_dtype == "int8"
and dequantize_call.args[0].checked_type.dtype == "int8"
)

def qnn_conv2d_pattern():
"""Create pattern for qnn.conv2D with optional fused relu."""
qnn_conv2d = is_op("qnn.conv2d")(
wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
)
bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant())
req = is_op("qnn.requantize")(
qnn_conv2d | bias_add, is_constant(), is_constant(), is_constant(), is_constant()
)
clip_or_req = req.optional(is_op("clip"))
return clip_or_req

def check_qnn_conv2d(pattern):
"""Check if the Conv2D is supported by CMSIS-NN."""
if str(pattern.op.name) == "clip":
relu = pattern
requantize = relu.args[0]
else:
requantize = pattern
requantize_input = requantize.args[0]
bias_add = None
bias_dtype = "int32"
if str(requantize_input.op.name) == "nn.bias_add":
bias_add = requantize_input
conv2d = bias_add.args[0]
bias_dtype = bias_add.args[1].checked_type.dtype
else:
conv2d = requantize_input
conv2d_input = conv2d.args[0]
conv2d_weight = conv2d.args[1]

# kernel zero_point should be 0
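# (the int8 conv kernels follow the TFLite convention of symmetric weight
# quantization, so every per-channel kernel offset must be 0)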
kernel_zp = conv2d.args[3].data.numpy()
kernel_zp = [kernel_zp] if kernel_zp.ndim == 0 else kernel_zp

return (
conv2d.attrs.out_dtype == "int32"
and conv2d.attrs.padding[2] == 0
and conv2d.attrs.padding[3] == 0
and conv2d_input.checked_type.dtype == "int8"
and conv2d_weight.checked_type.dtype == "int8"
and pattern.checked_type.dtype == "int8"
and bias_dtype == "int32"
and all([zp == 0 for zp in kernel_zp])
)

def binary_op_pattern(op):
"""Matches QNN binary operation"""
return is_op(f"qnn.{op}")(
@@ -97,23 +147,16 @@ def binary_op_pattern(op):
is_constant(),
)

def check_quantized_binary_op(extract):
def check_qnn_binary_op(extract):
"""Check if multiply is supported by CMSIS-NN."""
return (
extract.args[0].checked_type.dtype == "int8"
and extract.args[1].checked_type.dtype == "int8"
)

return [
("cmsis-nn.quantized_softmax", softmax_pattern(), check_quantized_softmax),
(
"cmsis-nn.quantized_mul",
binary_op_pattern("mul"),
check_quantized_binary_op,
),
(
"cmsis-nn.quantized_add",
binary_op_pattern("add"),
check_quantized_binary_op,
),
("cmsis-nn.qnn_softmax", qnn_softmax_pattern(), check_qnn_softmax),
("cmsis-nn.qnn_conv2d", qnn_conv2d_pattern(), check_qnn_conv2d),
("cmsis-nn.qnn_mul", binary_op_pattern("mul"), check_qnn_binary_op),
("cmsis-nn.qnn_add", binary_op_pattern("add"), check_qnn_binary_op),
]
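
A minimal usage sketch (not part of this diff; it assumes a quantized Relay module `mod` with parameters `params`): the `partition_for_cmsisnn` entry point shown above applies these patterns, partitions the graph for the cmsis-nn target, and extracts constants from the partitioned functions.

    from tvm.relay.op.contrib import cmsisnn

    # Runs pattern matching, partitioning and the CMSIS-NN-specific passes added in this PR.
    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(mod, params)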
28 changes: 8 additions & 20 deletions src/relay/backend/aot_executor_codegen.cc
@@ -648,25 +648,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
use_unpacked_api_(target_host->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {}

LoweredOutput Codegen(relay::Function func, String mod_name) {
AOTOnDemandAllocator initial_aot_allocator;
initial_aot_allocator.Run(func);

// Pre-lowering storage map and memory plan
// TODO(mbs): Why plan memory and update workspace sizes before lowering?
StorageMap initial_storage_map = initial_aot_allocator.GetStorageMap();
StaticMemoryPlan memory_plan(initial_storage_map);

IRModule mod = IRModule::FromExpr(func);

backend::FunctionInfo func_info;

if (memory_plan.defined()) {
// TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
func_info = tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan->expr_to_storage_info);
mod = WithAttr(mod, "main_func_info", func_info);
}

IRModule lowered_mod = tec::LowerTEPass(mod_name, [this](Function func) {
IRModule lowered_mod = tec::LowerTEPass(mod_name, [this](BaseFunc func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
@@ -683,12 +666,17 @@ class AOTExecutorCodegen : public MixedModeVisitor {
auto lowered_main = lowered_mod->Lookup("main");
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());

// Post-lowering storage map for writing main func - this should be the same map as previously
// created, just referencing the new expressions created from lowering
// Post-lowering storage map for writing main func
AOTOnDemandAllocator final_aot_allocator;
final_aot_allocator.Run(lowered_main_func);
storage_device_map_ = final_aot_allocator.GetStorageMap();

// TODO(@electriclilies, @jroesch, @Mousius): remove UpdateMainWorkspaceSize
StaticMemoryPlan memory_plan(storage_device_map_);
backend::FunctionInfo func_info =
tec::UpdateMainWorkspaceSize(lowered_mod, targets_, memory_plan->expr_to_storage_info);
lowered_mod = WithAttr(lowered_mod, "main_func_info", func_info);

for (auto input : lowered_main_func->params) {
input_vars_.push_back(input);
main_signature_.push_back(tir::Var("input", DataType::Handle()));
168 changes: 168 additions & 0 deletions src/relay/backend/contrib/cmsisnn/extract_constants.cc
@@ -0,0 +1,168 @@

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file extract_constants.cc
* \brief Pushes constants within partitioned functions all the way up to main()
*/

#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/ndarray.h>

#include "../../../qnn/utils.h"
#include "../../../transforms/pattern_utils.h"

namespace tvm {
namespace relay {
namespace contrib {
namespace cmsisnn {

/*!
* \brief This Mutator finds all functions with constants. Constants are replaced with function
* parameter variables. Constants are pushed all the way up to main().
*/
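// Illustrative sketch only (function names and attributes are hypothetical, not taken
// from this patch): conceptually the rewrite turns
//
//   def @main(%input) { @cmsisnn_0(%input) }
//   def @cmsisnn_0(%x, Compiler="cmsis-nn") { qnn.mul(%x, meta[relay.Constant][0], ...) }
//
// into
//
//   def @main(%input) { @cmsisnn_0(%input, meta[relay.Constant][0]) }
//   def @cmsisnn_0(%x, %c, Compiler="cmsis-nn") { qnn.mul(%x, %c, ...) }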
class ExtractConstantsMutator : public MixedModeMutator {
public:
explicit ExtractConstantsMutator(const IRModule& mod) : mod_(mod) {}

private:
String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); }

Expr VisitExpr_(const FunctionNode* function) final {
Function func = GetRef<Function>(function);
function_to_constants_.Set(func, Array<Constant>{});
functions_.push_back(func);
auto new_body = VisitExpr(func->body);
functions_.pop_back();
if (function_to_constants_[func].size()) {
func = Function(FreeVars(new_body), new_body, func->ret_type, FreeTypeVars(new_body, mod_),
func->attrs);
}
return func;
}

Expr Rewrite_(const CallNode* call, const Expr& post) final {
Expr final_call = post;
auto* post_call = post.as<CallNode>();

// Replace Constant arguments with Vars for ML Operators
// Perform this for non-main Call Nodes only
if (!functions_.empty() && call->op.as<OpNode>()) {
Array<Expr> new_args;
for (auto& arg : post_call->args) {
auto* const_arg = arg.as<ConstantNode>();
if (const_arg && !const_arg->is_scalar()) {
Var var_arg = Var(gen_var_name(), const_arg->tensor_type());
new_args.push_back(var_arg);
const Function& last_func = functions_.back();
Array<Constant> fconstants(function_to_constants_[last_func]);
fconstants.push_back(GetRef<Constant>(const_arg));
function_to_constants_.Set(last_func, fconstants);
} else {
new_args.push_back(arg);
}
}
final_call = Call(call->op, new_args, call->attrs, {});
}

// Since the constants are kicked out of partitioned functions
// a new call to global function is needed
if (auto* glob_var_node = post_call->op.as<GlobalVarNode>()) {
auto glob_var = GetRef<GlobalVar>(glob_var_node);
auto glob_func = Downcast<Function>(mod_->Lookup(glob_var));
auto new_glob_func = VisitExpr(glob_func);
if (!new_glob_func.same_as(glob_func)) {
mod_->Update(glob_var, Downcast<Function>(new_glob_func));
Array<Expr> new_args = post_call->args;
ICHECK(function_to_constants_.find(glob_func) != function_to_constants_.end());
for (auto constant : function_to_constants_.at(glob_func)) {
new_args.push_back(constant);
}
final_call = Call(glob_var, new_args);
}
}

// Since the constants are kicked out of the local partitioned functions
// a new call to local function is needed
// Also, pass on the constants to the callee of this function to support nested functions
if (auto* func_node = call->op.as<FunctionNode>()) {
Contributor: Can we refactor this section and the one above? I see a bit of code duplication, and I think the only difference is the origin of the Function (local vs. global).

Contributor Author (@asparkhi, Oct 25, 2021): After some thought, except for VisitExpr(Function) there is nothing common between those two blocks. From a readability point of view, I think it's better to keep them separate.

Contributor: Ack

Function func = GetRef<Function>(func_node);
auto new_func = VisitExpr(func);
if (!new_func.same_as(func)) {
Array<Expr> new_args = post_call->args;
ICHECK(function_to_constants_.find(func) != function_to_constants_.end());
const Function& last_func = functions_.back();
Array<Constant> fconstants(function_to_constants_[last_func]);
for (auto constant : function_to_constants_.at(func)) {
fconstants.push_back(constant);
Var var_arg = Var(gen_var_name(), constant->tensor_type());
new_args.push_back(var_arg);
}
function_to_constants_.Set(last_func, fconstants);
final_call = Call(new_func, new_args);
}
}

return final_call;
}

private:
/* \brief Updated module where all calls have replaced constants with new variables */
IRModule mod_;
/* \brief Maintains mapping of original function to the replaced constants */
Map<Function, Array<Constant>> function_to_constants_;
/* \brief Stack of functions to determine scope while filling up function_to_constants_ */
Array<Function> functions_;
/* \brief Keeps track of variables being created */
int var_count_ = 0;
};

/*! \brief Kicks all constants out of the partitioned function into main() */
IRModule ExtractConstants(const IRModule& mod) {
String func_name;
Function func;

auto extract_constants = ExtractConstantsMutator(mod);
Function main_func = Downcast<Function>(mod->Lookup("main"));
auto new_main_body = extract_constants.VisitExpr(main_func->body);
if (!new_main_body.same_as(main_func->body)) {
auto main_var = mod->GetGlobalVar("main");
auto new_main_func = Function(main_func->params, new_main_body, main_func->ret_type,
main_func->type_params, main_func->attrs);
mod->Update(main_var, new_main_func);
}
return mod;
}

transform::Pass ExtractConstantsFromPartitionedFunction() {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
[=](IRModule m, transform::PassContext pc) { return ExtractConstants(m); };
return tvm::transform::CreateModulePass(pass_func, 0, "ExtractConstantsFromPartitionedFunction",
{});
}

TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.ExtractConstantsFromPartitionedFunction")
.set_body_typed(ExtractConstantsFromPartitionedFunction);

} // namespace cmsisnn
} // namespace contrib
} // namespace relay
} // namespace tvm
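
A minimal sketch of running just this pass from Python (assuming a partitioned IRModule `mod`; the Python-side name comes from the `_init_api("relay.ext.cmsisnn.transform", ...)` call added to cmsisnn.py above):

    from tvm.relay.op.contrib import cmsisnn

    # The registered global returns a module pass; applying it yields the rewritten IRModule.
    mod = cmsisnn.ExtractConstantsFromPartitionedFunction()(mod)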