Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions python/tvm/relay/testing/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@ def conv2d_single_function(ifm_tensor):

return conv2d_single_function

def load_from_file(self, model_file, shapes):
"""Load tflite model from a tflite file"""
for i, shape in enumerate(shapes):
input_name = "input_" + str(i)
self.shape_dict.update({input_name: shape})
self.dtype_dict.update({input_name: self.dtype})

with open(model_file, "rb") as f:
self.serial_model = f.read()

def create_tflite_model(self, tfl_function, shapes, ranges=None):
"""Creates TFLite serial graph"""
tensor_specs = []
Expand Down
7 changes: 4 additions & 3 deletions src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ namespace cmsisnn {

class CodeGenCMSISNN : public codegen::CodeGenCHost {
public:
void Init(bool output_ssa, bool emit_asserts, std::string target_str) {
void Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl, std::string target_str) {
std::unordered_set<std::string> devices;
devices.insert("cmsis-nn");
CodeGenCHost::Init(output_ssa, emit_asserts, target_str, devices);
CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, target_str, devices);
}

/*!
Expand Down Expand Up @@ -491,9 +491,10 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
runtime::Module TIRToRuntime(IRModule mod, Target target) {
bool output_ssa = false;
bool emit_asserts = false;
bool emit_fwd_func_decl = false;
CodeGenCMSISNN codegen;
Array<String> function_names;
codegen.Init(output_ssa, emit_asserts, target->str());
codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str());

std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
for (auto kv : mod->functions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ class CodeGenExampleTargetHook : public codegen::CodeGenCHost {
runtime::Module TIRToRuntime(IRModule mod, Target target) {
bool output_ssa = false;
bool emit_asserts = false;
bool emit_fwd_func_decl = false;
CodeGenExampleTargetHook codegen;
Array<String> function_names;
std::unordered_set<std::string> devices;
codegen.Init(output_ssa, emit_asserts, target->str(), devices);
codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices);
for (auto kv : mod->functions) {
auto prim_func = Downcast<PrimFunc>(kv.second);
auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
Expand Down
7 changes: 4 additions & 3 deletions src/relay/backend/contrib/uma/tir_to_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class UMACodegen : public codegen::CodeGenCHost {
public:
explicit UMACodegen(String target_str) : target_str_(target_str) {}

void Init(bool output_ssa, bool emit_asserts) {
void Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl) {
auto includes_pf =
tvm::runtime::Registry::Get("relay.ext.uma.codegen_c_includes_" + target_str_);
if (includes_pf) {
Expand All @@ -46,7 +46,7 @@ class UMACodegen : public codegen::CodeGenCHost {
}
std::unordered_set<std::string> devices;
devices.insert(target_str_);
CodeGenCHost::Init(output_ssa, emit_asserts, target_str_, devices);
CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, target_str_, devices);
}

/*!
Expand All @@ -63,9 +63,10 @@ class UMACodegen : public codegen::CodeGenCHost {
runtime::Module TIRToRuntime(IRModule mod, Target target) {
bool output_ssa = false;
bool emit_asserts = false;
bool emit_fwd_func_decl = false;
UMACodegen codegen(target->kind->name);
Array<String> function_names;
codegen.Init(output_ssa, emit_asserts);
codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl);
for (auto kv : mod->functions) {
auto prim_func = Downcast<PrimFunc>(kv.second);
auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
Expand Down
5 changes: 3 additions & 2 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);

this->PrintFuncPrefix();
this->PrintFuncPrefix(stream);
this->PrintExtraAttrs(f);
this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";

Expand Down Expand Up @@ -127,7 +127,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
this->stream << "}\n\n";
}

void CodeGenC::PrintFuncPrefix() { stream << "void"; }
void CodeGenC::PrintFuncPrefix(std::ostream& os) { os << "void"; }

void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {}

Expand Down Expand Up @@ -540,6 +540,7 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
ICHECK_GE(op->args.size(), 1U);
auto func = Downcast<StringImm>(op->args[0]);
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), func->value, op->args, true, os);
this->GenerateForwardFunctionDeclarations(func->value, op->args);
} else if (op_attr_global_symbol_.count(call_op)) {
// call extern if the op itself have a global symbol.
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_attr_global_symbol_[call_op],
Expand Down
13 changes: 11 additions & 2 deletions src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
* \brief Finalize the compilation and return the code.
* \return The code.
*/
std::string Finish();
virtual std::string Finish();
/*!
* \brief Print the Stmt n to CodeGenC->stream
* \param n The statement to be printed.
Expand All @@ -99,10 +99,11 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
// The following parts are overloadable print operations.
/*!
* \brief Print the function header before the argument list
* \param os The output stream
*
* Example: stream << "void";
*/
virtual void PrintFuncPrefix(); // NOLINT(*)
virtual void PrintFuncPrefix(std::ostream& os); // NOLINT(*)
/*!
* \brief Print extra function attributes
*
Expand Down Expand Up @@ -230,6 +231,14 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
*/
virtual bool IsScopePartOfType() const { return true; }

/*!
* \brief Generate forward function declarations.
* \param global_symbol The symbolc of the target function.
* \param args The arguments to the function.
* \param os The output stream.
*/
virtual void GenerateForwardFunctionDeclarations(String global_symbol,
const Array<PrimExpr>& args) {}
/*!
* \brief Print external function call.
* \param ret_type The return type.
Expand Down
57 changes: 46 additions & 11 deletions src/target/source/codegen_c_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ namespace codegen {

CodeGenCHost::CodeGenCHost() { module_name_ = name_supply_->FreshName("__tvm_module_ctx"); }

void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, std::string target_str,
const std::unordered_set<std::string>& devices) {
void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl,
std::string target_str, const std::unordered_set<std::string>& devices) {
emit_asserts_ = emit_asserts;
emit_fwd_func_decl_ = emit_fwd_func_decl;
declared_globals_.clear();
decl_stream << "// tvm target: " << target_str << "\n";
decl_stream << "#define TVM_EXPORTS\n";
Expand All @@ -73,17 +74,18 @@ void CodeGenCHost::InitGlobalContext() {

void CodeGenCHost::DefineModuleName() { decl_stream << "void* " << module_name_ << " = NULL;\n"; }

void CodeGenCHost::AddFunction(const PrimFunc& f) {
void CodeGenCHost::AddFunction(const PrimFunc& f, bool emit_fwd_func_decl) {
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
<< "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute";
function_names_.push_back(global_symbol.value());

emit_fwd_func_decl_ = emit_fwd_func_decl;
CodeGenC::AddFunction(f);
if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
function_names_.push_back(runtime::symbol::tvm_module_main);
stream << "// CodegenC: NOTE: Auto-generated entry function\n";
PrintFuncPrefix();
PrintFuncPrefix(stream);
stream << " " << tvm::runtime::symbol::tvm_module_main
<< "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, "
<< "int* out_ret_tcode, void* resource_handle) {\n";
Expand All @@ -93,18 +95,49 @@ void CodeGenCHost::AddFunction(const PrimFunc& f) {
}
}

void CodeGenCHost::PrintFuncPrefix() { // NOLINT(*)
stream << "#ifdef __cplusplus\n"
<< "extern \"C\"\n"
<< "#endif\n"
<< "TVM_DLL int32_t";
void CodeGenCHost::GenerateForwardFunctionDeclarations(String global_symbol,
const Array<PrimExpr>& args) {
if (!emit_fwd_func_decl_) {
return;
}
for (auto& func_already_defined : GetFunctionNames()) {
if (global_symbol == func_already_defined) {
return;
}
}
this->PrintFuncPrefix(fwd_decl_stream);
fwd_decl_stream << " " << global_symbol << "(";
for (size_t i = 1; i < args.size(); ++i) {
CodeGenSourceBase::PrintType(GetType(args[i]), fwd_decl_stream);
fwd_decl_stream << " ", this->PrintExpr(args[i], fwd_decl_stream);
if (i < args.size() - 1) {
fwd_decl_stream << ", ";
}
}
fwd_decl_stream << ");\n";
}

void CodeGenCHost::PrintFuncPrefix(std::ostream& os) { // NOLINT(*)
os << "#ifdef __cplusplus\n"
<< "extern \"C\"\n"
<< "#endif\n"
<< "TVM_DLL int32_t";
}

void CodeGenCHost::PrintFinalReturn() { // NOLINT(*)
this->PrintIndent();
stream << "return 0;\n";
}

std::string CodeGenCHost::Finish() { // NOLINT(*)
std::string ret = decl_stream.str();
if (emit_fwd_func_decl_) {
ret += fwd_decl_stream.str();
}
ret += stream.str();
return ret;
}

void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
Expand Down Expand Up @@ -391,6 +424,7 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
bool emit_asserts = false;
bool emit_fwd_func_decl = true;

std::unordered_set<std::string> devices;
if (mod->GetAttr<Map<GlobalVar, String>>("device_contexts") != nullptr) {
Expand All @@ -402,7 +436,7 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
}

CodeGenCHost cg;
cg.Init(output_ssa, emit_asserts, target->str(), devices);
cg.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices);
cg.SetConstantsByteAlignment(target->GetAttr<Integer>("constants-byte-alignment").value_or(16));
PrimFunc aot_executor_fn;

Expand Down Expand Up @@ -438,7 +472,8 @@ runtime::Module BuildCHost(IRModule mod, Target target) {

// Add __tvm_main__
if (aot_executor_fn.defined()) {
cg.AddFunction(aot_executor_fn);
emit_fwd_func_decl = true;
cg.AddFunction(aot_executor_fn, emit_fwd_func_decl);
}

// NOTE: it's possible that kRuntime attr is not attached when the mod was built with tvm.build().
Expand Down
11 changes: 8 additions & 3 deletions src/target/source/codegen_c_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,12 @@ namespace codegen {
class CodeGenCHost : public CodeGenC {
public:
CodeGenCHost();
void Init(bool output_ssa, bool emit_asserts, std::string target_str,
void Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl, std::string target_str,
const std::unordered_set<std::string>& devices);

void InitGlobalContext();
void AddFunction(const PrimFunc& f);
void AddFunction(const PrimFunc& f, bool emit_fwd_func_decl = false);
std::string Finish() final;
/*!
* \brief Add functions from the (unordered) range to the current module in a deterministic
* order. This helps with debugging.
Expand All @@ -55,7 +56,7 @@ class CodeGenCHost : public CodeGenC {
void DefineModuleName();

void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void PrintFuncPrefix() final; // NOLINT(*)
void PrintFuncPrefix(std::ostream& os) final; // NOLINT(*)
void PrintFinalReturn() final; // NOLINT(*)

// overload visitor functions
Expand All @@ -68,6 +69,8 @@ class CodeGenCHost : public CodeGenC {

void VisitStmt_(const AssertStmtNode* op) final; // NOLINT(*)

virtual void GenerateForwardFunctionDeclarations(String global_symbol,
const Array<PrimExpr>& args); // NOLINT(*)
Array<String> GetFunctionNames() { return function_names_; }

private:
Expand All @@ -87,6 +90,8 @@ class CodeGenCHost : public CodeGenC {
Array<String> function_names_;
/*! \brief whether to emit asserts in the resulting C code */
bool emit_asserts_;
/*! \brief whether to emit forwared function declarations in the resulting C code */
bool emit_fwd_func_decl_;

FunctionInfo GetFunctionInfo(const CallNode* op, bool has_resource_handle);
std::string GetPackedName(const CallNode* op);
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ void CodeGenCUDA::Init(bool output_ssa) {
ICHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state);
}

void CodeGenCUDA::PrintFuncPrefix() { stream << "extern \"C\" __global__ void"; }
void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ void"; }

class ThreadIdxExtractor : public tir::StmtVisitor {
private:
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class CodeGenCUDA final : public CodeGenC {
return (enable_fp16_ || enable_bf16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
}
// override behavior
void PrintFuncPrefix() final;
void PrintFuncPrefix(std::ostream& os) final;
void PrintExtraAttrs(const PrimFunc& f) final;
void VisitStmt_(const ForNode* op) final;
void PrintStorageSync(const CallNode* op) final;
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ void CodeGenOpenCL::InitFuncState(const PrimFunc& f) {
}
}

void CodeGenOpenCL::PrintFuncPrefix() { stream << "__kernel void"; }
void CodeGenOpenCL::PrintFuncPrefix(std::ostream& os) { os << "__kernel void"; }

void CodeGenOpenCL::PreFunctionBody(const PrimFunc& f) {
for (Var arg : f->params) {
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_opencl.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class CodeGenOpenCL final : public CodeGenC {

// override print thread tag.
void InitFuncState(const PrimFunc& f) final;
void PrintFuncPrefix() final; // NOLINT(*)
void PrintFuncPrefix(std::ostream& os) final; // NOLINT(*)
void PreFunctionBody(const PrimFunc& f) final; // NOLINT(*)
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
Expand Down
2 changes: 2 additions & 0 deletions src/target/source/codegen_source_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ class CodeGenSourceBase {
std::ostringstream decl_stream;
/*! \brief the stream to be printed */
std::ostringstream stream;
/*! \brief the forward declaration stream */
std::ostringstream fwd_decl_stream;
/*! \brief name of each variable */
std::unordered_map<const tir::VarNode*, std::string> var_idmap_;
/*! \brief NameSupply for allocation */
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_vhls.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ void CodeGenVivadoHLS::PrintType(DataType t, std::ostream& os) {
}
}

void CodeGenVivadoHLS::PrintFuncPrefix() { stream << "extern \"C\" void"; }
void CodeGenVivadoHLS::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" void"; }

void CodeGenVivadoHLS::PreFunctionBody(const PrimFunc& f) {
for (size_t i = 0; i < f->params.size(); ++i) {
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_vhls.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class CodeGenVivadoHLS final : public CodeGenC {
void Init(bool output_ssa);
void PrintType(DataType t, std::ostream& os);

void PrintFuncPrefix() final;
void PrintFuncPrefix(std::ostream& os) final;
void PreFunctionBody(const PrimFunc& f) final;
void VisitExpr_(const MinNode* op, std::ostream& os) final;
void VisitExpr_(const MaxNode* op, std::ostream& os) final;
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/aot/corstone300.mk
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ ETHOSU_PATH=/opt/arm/ethosu
DRIVER_PATH=${ETHOSU_PATH}/core_driver
CMSIS_PATH=${ETHOSU_PATH}/cmsis
PLATFORM_PATH=${ETHOSU_PATH}/core_platform/targets/corstone-300
PKG_COMPILE_OPTS = -g -Wall -O2 -Wno-incompatible-pointer-types -Wno-format -mcpu=${MCPU}${MCPU_FLAGS} -mthumb -mfloat-abi=${MFLOAT_ABI} -std=gnu99
PKG_COMPILE_OPTS = -g -Wall -O2 -Wno-incompatible-pointer-types -Wno-format -Werror-implicit-function-declaration -mcpu=${MCPU}${MCPU_FLAGS} -mthumb -mfloat-abi=${MFLOAT_ABI} -std=gnu99
CMAKE = /opt/arm/cmake/bin/cmake
CC = arm-none-eabi-gcc
AR = arm-none-eabi-ar
Expand Down
Loading