Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 53 additions & 14 deletions python/tvm/testing/aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import subprocess
import tarfile
import logging
from typing import Any, NamedTuple, Union, Tuple, Optional, List, Dict
from typing import Any, NamedTuple, Union, Tuple, Optional, List, Dict, Callable
import numpy as np

import tvm
Expand Down Expand Up @@ -200,6 +200,7 @@ def _emit_main_prologue(
compiled_models,
interface_api,
use_stack_allocator=True,
debug_last_error=False,
):
if use_stack_allocator:
workspace_define = f"#define WORKSPACE_SIZE ({workspace_bytes}"
Expand Down Expand Up @@ -243,11 +244,28 @@ def _emit_main_prologue(
va_start(args, msg);
vfprintf(stdout, msg, args);
va_end(args);
}\n
TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) {}
int main(){\n
}
"""
)
if debug_last_error:
main_file.write(
"""\n
tvm_crt_error_t TVMPlatformTimerStart() {
return kTvmErrorFunctionCallNotImplemented;
}
tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) {
return kTvmErrorFunctionCallNotImplemented;
}
const TVMModule* TVMSystemLibEntryPoint(void) { return NULL; }
"""
)
else:
main_file.write(
"""\n
TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) {}
"""
)
main_file.write("\nint main(){\n")
main_file.write(custom_prologue)


Expand Down Expand Up @@ -332,10 +350,10 @@ def _emit_main_data_setup(main_file, input_map, output_map, mod_name):


def _emit_main_c_interface_call(
main_file, devices, workspace_pool_names, mod_name, use_workspace_io
main_file, devices, workspace_pool_names, mod_name, use_workspace_io, debug_last_error
):
sub_strings = list()
sub_strings.append(f'{_mangle_name(mod_name,"run")}(')
sub_strings.append(f'if ({_mangle_name(mod_name,"run")}(')
if not use_workspace_io:
sub_strings.append(f'&{_mangle_name(mod_name,"inputs")}, ')
sub_strings.append(f'&{_mangle_name(mod_name,"outputs")}, ')
Expand All @@ -346,10 +364,14 @@ def _emit_main_c_interface_call(
# Removing the last two characters that is a comma and a space
sub_strings[-1] = sub_strings[-1][:-2]
# Adding brackets and newline instead
sub_strings[-1] = sub_strings[-1] + ");\n"

sub_strings[-1] = sub_strings[-1] + ") == -1) {\n"
main_file_string = "".join(sub_strings)
main_file.write(main_file_string)
if debug_last_error:
main_file.write(f'\tprintf("ERROR: %s\\n", TVMGetLastError());\n')
main_file.write(f'\tprintf("{AOT_FAILURE_TOKEN}\\n");\n')
main_file.write(f"\treturn -1;\n")
main_file.write("}\n")


def _emit_main_fake_packed_values(main_file):
Expand Down Expand Up @@ -447,13 +469,15 @@ def _emit_main_epilogue(main_file, custom_epilogue):
main_file.write("}\n")


def _emit_main_common_includes(main_file, custom_includes):
def _emit_main_common_includes(main_file, custom_includes, debug_last_error):
main_file.write("#include <stdio.h>\n")
main_file.write("#include <stdarg.h>\n")
main_file.write("#include <stdlib.h>\n")
main_file.write("#include <math.h>\n")
main_file.write('#include "tvm/runtime/c_runtime_api.h"\n')
main_file.write('#include "tvm/runtime/crt/stack_allocator.h"\n')
if debug_last_error:
main_file.write('#include "tvm/runtime/crt/module.h"\n')
for include in custom_includes:
main_file.write(f'#include "{include}"\n')

Expand All @@ -474,12 +498,13 @@ def _create_main(
workspace_bytes,
use_stack_allocator=True,
use_workspace_io=False,
debug_last_error=False,
):
file_path = pathlib.Path(f"{output_path}/" + test_name).resolve()
# create header file
raw_path = file_path.with_suffix(".c").resolve()
with open(raw_path, "w") as main_file:
_emit_main_common_includes(main_file, custom_includes)
_emit_main_common_includes(main_file, custom_includes, debug_last_error)

if interface_api == "c":
for compiled_model in compiled_models:
Expand All @@ -497,6 +522,7 @@ def _create_main(
compiled_models,
interface_api,
use_stack_allocator,
debug_last_error,
)
if use_stack_allocator:
_emit_main_init_memory_manager(main_file)
Expand Down Expand Up @@ -529,6 +555,7 @@ def _create_main(
list(workspace_pool_names.keys()),
model.name,
use_workspace_io,
debug_last_error,
)
else:
_emit_main_fake_packed_values(main_file)
Expand Down Expand Up @@ -701,6 +728,8 @@ def run_and_check(
test_dir: str = None,
verbose: bool = False,
use_workspace_io: bool = False,
debug_last_error: bool = False,
checker: Optional[Callable[[str], bool]] = None,
):
"""
This method uses the original test data and compiled runtime.Modules
Expand Down Expand Up @@ -780,8 +809,12 @@ def run_and_check_body(base_path):
workspace_bytes,
use_stack_allocator,
use_workspace_io,
debug_last_error,
)

if checker and (not checker(base_path)):
return False

# Verify that compiles fine
file_dir = os.path.dirname(os.path.abspath(__file__))
makefile_dir = os.path.join(file_dir, "../../../tests/python/relay/aot")
Expand Down Expand Up @@ -829,11 +862,13 @@ def run_and_check_body(base_path):
with open(run_log_path) as run_log:
assert AOT_SUCCESS_TOKEN in run_log.read()

return True

if test_dir is None:
tmpdir = utils.tempdir()
run_and_check_body(os.path.join(tmpdir.path, "test"))
return run_and_check_body(os.path.join(tmpdir.path, "test"))
else:
run_and_check_body(test_dir)
return run_and_check_body(test_dir)


def compile_and_run(
Expand All @@ -852,7 +887,9 @@ def compile_and_run(
test_dir: str = None,
verbose: bool = False,
schedule_name: str = None,
):
debug_last_error: bool = False,
checker: Optional[Callable[[str], bool]] = None,
) -> bool:
"""This is a wrapper API to compile and run models as test for AoT

Parameters
Expand Down Expand Up @@ -883,7 +920,7 @@ def compile_and_run(
schedule_name=schedule_name,
)

run_and_check(
return run_and_check(
models=compiled_test_mods,
runner=runner,
interface_api=interface_api,
Expand All @@ -893,6 +930,8 @@ def compile_and_run(
data_linkage=data_linkage,
test_dir=test_dir,
verbose=verbose,
debug_last_error=debug_last_error,
checker=checker,
)


Expand Down
2 changes: 2 additions & 0 deletions src/relay/backend/contrib/cmsisnn/compiler_attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,13 @@ Target CreateTarget(const tvm::transform::PassContext& ctx) {

String mcpu = cfg.value()->mcpu;
Array<String> mattr = {cfg.value()->mattr};
Bool debug_last_error = cfg.value()->debug_last_error;

Target cmsis_nn_target(TargetJSON{
{"kind", String("cmsis-nn")},
{"mcpu", mcpu},
{"mattr", mattr},
{"debug_last_error", debug_last_error},
});

return cmsis_nn_target;
Expand Down
4 changes: 4 additions & 0 deletions src/relay/backend/contrib/cmsisnn/compiler_attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ namespace cmsisnn {
struct CMSISNNCompilerConfigNode : public tvm::AttrsNode<CMSISNNCompilerConfigNode> {
String mcpu;
String mattr;
Bool debug_last_error = Bool(false);

TVM_DECLARE_ATTRS(CMSISNNCompilerConfigNode, "ext.attrs.CMSISNNCompilerConfigNode") {
TVM_ATTR_FIELD(mcpu)
Expand All @@ -47,6 +48,9 @@ struct CMSISNNCompilerConfigNode : public tvm::AttrsNode<CMSISNNCompilerConfigNo
TVM_ATTR_FIELD(mattr)
.describe("The attributes to configure CMSIS-NN (i.e. +nodsp, +nomve)")
.set_default("");
TVM_ATTR_FIELD(debug_last_error)
.describe("Whether to enable storing the last error")
.set_default(Bool(false));
}
};

Expand Down
1 change: 1 addition & 0 deletions src/relay/backend/contrib/cmsisnn/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ runtime::Module TIRToRuntime(IRModule mod, Target target);
TVM_REGISTER_TARGET_KIND("cmsis-nn", kDLCPU)
.add_attr_option<Array<String>>("mattr")
.add_attr_option<String>("mcpu")
.add_attr_option<Bool>("debug_last_error")
.set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, RelayToTIR())
.set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime)
.set_target_parser(tvm::target::parsers::cpu::ParseTarget);
Expand Down
59 changes: 41 additions & 18 deletions src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/ir/transform.h>

#include <cmath>
#include <fstream>
#include <map>
Expand All @@ -26,6 +28,7 @@
#include "../../../../runtime/file_utils.h"
#include "../../../../target/source/codegen_c.h"
#include "../../../../target/source/codegen_c_host.h"
#include "compiler_attrs.h"

namespace tvm {
using namespace tir;
Expand All @@ -35,7 +38,9 @@ namespace cmsisnn {

class CodeGenCMSISNN : public codegen::CodeGenCHost {
public:
void Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl, std::string target_str) {
void Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl, std::string target_str,
bool debug_last_error) {
this->debug_last_error = debug_last_error;
std::unordered_set<std::string> devices;
devices.insert("cmsis-nn");
CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, target_str, devices);
Expand All @@ -49,6 +54,9 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); }

private:
/*! * \brief Enable storing the last error */
bool debug_last_error;

/*! * \brief CMSIS-NN context buffer info */
struct CMSISNNContextBuffer {
std::string name;
Expand Down Expand Up @@ -357,13 +365,7 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
stream << "&" << filter_dim << ", " << filter_data << ", ";
stream << "&" << bias_dim << ", " << bias_data << ", ";
stream << "&" << output_dim << ", " << output_data << ");\n";
PrintIndent();
stream << "if (status != ARM_CMSIS_NN_SUCCESS) {\n";
PrintIndent();
PrintIndent();
stream << "return -1;\n";
PrintIndent();
stream << "}\n";
EmitErrorCheck();
}

/*! * \brief Emits CMSIS-NN APIs for every call_extern comprising fully connected */
Expand Down Expand Up @@ -426,13 +428,7 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
stream << "&" << filter_dim << ", " << filter_data << ", ";
stream << "&" << bias_dim << ", " << bias_data << ", ";
stream << "&" << output_dim << ", " << output_data << ");\n";
PrintIndent();
stream << "if (status != ARM_CMSIS_NN_SUCCESS) {\n";
PrintIndent();
PrintIndent();
stream << "return -1;\n";
PrintIndent();
stream << "}\n";
EmitErrorCheck();
}

/*! * \brief Emits CMSIS-NN APIs for every call_extern comprising pooling ops */
Expand Down Expand Up @@ -480,24 +476,51 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
stream << "&" << input_dim << ", " << input_data << ", ";
stream << "&" << filter_dim << ", ";
stream << "&" << output_dim << ", " << output_data << ");\n";
EmitErrorCheck();
}

void EmitErrorCheck() {
auto emit_error = [&](std::string error) {
if (this->debug_last_error) {
stream << "TVMAPISetLastError(\"" << error << "\"); ";
}
};

PrintIndent();
stream << "if (status != ARM_CMSIS_NN_SUCCESS) {\n";
stream << "switch (!status) {\n";
PrintIndent();
stream << "case ARM_CMSIS_NN_SUCCESS: break;\n";
PrintIndent();
stream << "case ARM_CMSIS_NN_ARG_ERROR: ";
emit_error("ARM_CMSIS_NN_ARG_ERROR");
stream << "return -1;\n";
PrintIndent();
stream << "case ARM_CMSIS_NN_NO_IMPL_ERROR: ";
emit_error("ARM_CMSIS_NN_NO_IMPL_ERROR");
stream << "return -1;\n";
PrintIndent();
stream << "}\n";
}
};

static CMSISNNCompilerConfig GetCompilerAttrs() {
auto ctx = tvm::tir::transform::PassContext::Current();
Optional<CMSISNNCompilerConfig> cfg =
ctx->GetConfig<CMSISNNCompilerConfig>("relay.ext.cmsisnn.options");
if (!cfg.defined()) {
return AttrsWithDefaultValues<CMSISNNCompilerConfig>();
}
return cfg.value();
}

runtime::Module TIRToRuntime(IRModule mod, Target target) {
bool output_ssa = false;
bool emit_asserts = false;
bool emit_fwd_func_decl = false;
bool debug_last_error = GetCompilerAttrs()->debug_last_error;
CodeGenCMSISNN codegen;
Array<String> function_names;
codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str());

codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), debug_last_error);
std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
for (auto kv : mod->functions) {
funcs.push_back(kv);
Expand Down
Loading