Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/tvm/runtime/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ class TVM_DLL ModuleNode : public Object {
* \return Possible source code when available.
*/
virtual std::string GetSource(const std::string& format = "");
/*!
* \brief Get the format of the module, when available.
* \return Possible format when available.
*/
virtual std::string GetFormat();
/*!
* \brief Get packed function from current module by name.
*
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/micro/model_library_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,12 @@ def _populate_codegen_dir(mod, codegen_dir: str, module_name: str = None):

for dso_mod in dso_modules:
if dso_mod.type_key == "c":
assert dso_mod.format in ["c", "cc", "cpp"]
ext = dso_mod.format
index = mod_indices["src"]
mod_indices["src"] += 1
parent_dir = os.path.join(host_codegen_dir, "src")
file_name = os.path.join(parent_dir, f"{lib_name}{index}.c")
file_name = os.path.join(parent_dir, f"{lib_name}{index}.{ext}")
elif dso_mod.type_key == "llvm":
index = mod_indices["lib"]
mod_indices["lib"] += 1
Expand Down
22 changes: 20 additions & 2 deletions python/tvm/runtime/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,11 @@ def type_key(self):
"""Get type key of the module."""
return _ffi_api.ModuleGetTypeKey(self)

@property
def format(self):
"""Get the format of the module."""
return _ffi_api.ModuleGetFormat(self)

def get_source(self, fmt=""):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think Module can sometimes have multiple formats (e.g. CUDA modules do this I believe). what about leveraging this and modifying export_module to try with fmt="c" here and if not also try fmt="cc"?

"""Get source code from module, if available.

Expand Down Expand Up @@ -402,7 +407,12 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No
for index, module in enumerate(modules):
if fcompile is not None and hasattr(fcompile, "object_format"):
if module.type_key == "c":
object_format = "c"
assert module.format in [
"c",
"cc",
"cpp",
], "The module.format needs to be either c, cc or cpp"
object_format = module.format
has_c_module = True
else:
object_format = fcompile.object_format
Expand All @@ -411,7 +421,15 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No
object_format = "o"
else:
assert module.type_key == "c"
object_format = "c"
if len(module.format) > 0:
assert module.format in [
"c",
"cc",
"cpp",
], "The module.format needs to be either c, cc or cpp"
object_format = module.format
else:
object_format = "c"
if "cc" in kwargs:
if kwargs["cc"] == "nvcc":
object_format = "cu"
Expand Down
2 changes: 2 additions & 0 deletions src/relay/backend/contrib/ethosu/source_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ class EthosUModuleNode : public ModuleNode {

std::string GetSource(const std::string& format) final { return c_source; }

std::string GetFormat() { return "c"; }

Array<CompilationArtifact> GetArtifacts() { return compilation_artifacts_; }

/*!
Expand Down
9 changes: 9 additions & 0 deletions src/runtime/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) {
}
}

std::string ModuleNode::GetFormat() {
LOG(FATAL) << "Module[" << type_key() << "] does not support GetFormat";
return "";
}

bool RuntimeEnabled(const std::string& target) {
std::string f_name;
if (target == "cpu") {
Expand Down Expand Up @@ -179,6 +184,10 @@ TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey").set_body_typed([](Module mod) {
return std::string(mod->type_key());
});

TVM_REGISTER_GLOBAL("runtime.ModuleGetFormat").set_body_typed([](Module mod) {
return mod->GetFormat();
});

TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile);

TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile")
Expand Down
11 changes: 8 additions & 3 deletions src/target/source/source_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ class SourceModuleNode : public runtime::ModuleNode {

std::string GetSource(const std::string& format) final { return code_; }

std::string GetFormat() { return fmt_; }

protected:
std::string code_;
std::string fmt_;
Expand Down Expand Up @@ -101,10 +103,12 @@ class CSourceModuleNode : public runtime::ModuleNode {

std::string GetSource(const std::string& format) final { return code_; }

std::string GetFormat() { return fmt_; }

void SaveToFile(const std::string& file_name, const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
std::string meta_file = GetMetaFilePath(file_name);
if (fmt == "c" || fmt == "cu") {
if (fmt == "c" || fmt == "cc" || fmt == "cpp" || fmt == "cu") {
ICHECK_NE(code_.length(), 0);
SaveBinaryToFile(file_name, code_);
} else {
Expand Down Expand Up @@ -142,14 +146,15 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {

std::string GetSource(const std::string& format) final { return code_.str(); }

std::string GetFormat() { return fmt_; }
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
return PackedFunc(nullptr);
}

void SaveToFile(const std::string& file_name, const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
std::string meta_file = GetMetaFilePath(file_name);
if (fmt == "c") {
if (fmt == "c" || fmt == "cc" || fmt == "cpp") {
auto code_str = code_.str();
ICHECK_NE(code_str.length(), 0);
SaveBinaryToFile(file_name, code_str);
Expand Down Expand Up @@ -350,7 +355,7 @@ runtime::Module CreateCSourceCrtMetadataModule(const Array<runtime::Module>& mod
}
}
}
auto n = make_object<CSourceCrtMetadataModuleNode>(func_names, "cc", target, runtime, metadata);
auto n = make_object<CSourceCrtMetadataModuleNode>(func_names, "c", target, runtime, metadata);
auto csrc_metadata_module = runtime::Module(n);
for (const auto& mod : modules) {
csrc_metadata_module.Import(mod);
Expand Down
12 changes: 7 additions & 5 deletions tests/python/relay/aot/corstone300.mk
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@ QUIET ?= @
$(endif)

CRT_SRCS = $(shell find $(CRT_ROOT))
CODEGEN_SRCS = $(shell find $(abspath $(CODEGEN_ROOT)/host/src/*.c))
CODEGEN_OBJS = $(subst .c,.o,$(CODEGEN_SRCS))
C_CODEGEN_SRCS = $(shell find $(abspath $(CODEGEN_ROOT)/host/src/*.c))
CC_CODEGEN_SRCS = $(shell find $(abspath $(CODEGEN_ROOT)/host/src/*.cc))
C_CODEGEN_OBJS = $(subst .c,.o,$(C_CODEGEN_SRCS))
CC_CODEGEN_OBJS = $(subst .cc,.o,$(CC_CODEGEN_SRCS))
CMSIS_STARTUP_SRCS = $(shell find ${CMSIS_PATH}/Device/ARM/${ARM_CPU}/Source/*.c)
UART_SRCS = $(shell find ${PLATFORM_PATH}/*.c)

Expand All @@ -90,9 +92,9 @@ $(build_dir)/tvm_ethosu_runtime.o: $(TVM_ROOT)/src/runtime/contrib/ethosu/bare_m
$(QUIET)mkdir -p $(@D)
$(QUIET)$(CC) -c $(PKG_CFLAGS) -o $@ $^

$(build_dir)/libcodegen.a: $(CODEGEN_SRCS)
$(QUIET)cd $(abspath $(CODEGEN_ROOT)/host/src) && $(CC) -c $(PKG_CFLAGS) $(CODEGEN_SRCS)
$(QUIET)$(AR) -cr $(abspath $(build_dir)/libcodegen.a) $(CODEGEN_OBJS)
$(build_dir)/libcodegen.a: $(C_CODEGEN_SRCS) $(CC_CODEGEN_SRCS)
$(QUIET)cd $(abspath $(CODEGEN_ROOT)/host/src) && $(CC) -c $(PKG_CFLAGS) $(C_CODEGEN_SRCS) $(CC_CODEGEN_SRCS)
$(QUIET)$(AR) -cr $(abspath $(build_dir)/libcodegen.a) $(C_CODEGEN_OBJS) $(CC_CODEGEN_OBJS)
$(QUIET)$(RANLIB) $(abspath $(build_dir)/libcodegen.a)

${build_dir}/libcmsis_startup.a: $(CMSIS_STARTUP_SRCS)
Expand Down
13 changes: 10 additions & 3 deletions tests/python/relay/aot/default.mk
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ ENABLE_TVM_PLATFORM_ABORT_BACKTRACE = 0
DMLC_CORE=$(TVM_ROOT)/3rdparty/dmlc-core
PKG_COMPILE_OPTS = -g
CC = gcc
#CC = g++
AR = ar
RANLIB = ranlib
CC_OPTS = CC=$(CC) AR=$(AR) RANLIB=$(RANLIB)
Expand All @@ -39,17 +40,23 @@ $(endif)

aot_test_runner: $(build_dir)/aot_test_runner

source_libs= $(wildcard $(build_dir)/../codegen/host/src/*.c)
lib_objs =$(source_libs:.c=.o)
c_source_libs= $(wildcard $(build_dir)/../codegen/host/src/*.c)
cc_source_libs= $(wildcard $(build_dir)/../codegen/host/src/*.cc)
c_lib_objs =$(c_source_libs:.c=.o)
cc_lib_objs =$(cc_source_libs:.cc=.o)

$(build_dir)/aot_test_runner: $(build_dir)/test.c $(source_libs) $(build_dir)/stack_allocator.o $(build_dir)/crt_backend_api.o
$(build_dir)/aot_test_runner: $(build_dir)/test.c $(c_source_libs) $(cc_source_libs) $(build_dir)/stack_allocator.o $(build_dir)/crt_backend_api.o
$(QUIET)mkdir -p $(@D)
$(QUIET)$(CC) $(CFLAGS) $(PKG_CFLAGS) -o $@ $^ $(PKG_LDFLAGS) $(BACKTRACE_LDFLAGS) $(BACKTRACE_CFLAGS) -lm

$(build_dir)/%.o: $(build_dir)/../codegen/host/src/%.c
$(QUIET)mkdir -p $(@D)
$(QUIET)$(CC) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS)

$(build_dir)/%.o: $(build_dir)/../codegen/host/src/%.cc
$(QUIET)mkdir -p $(@D)
$(QUIET)$(CC) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS)

$(build_dir)/stack_allocator.o: $(STANDALONE_CRT_DIR)/src/runtime/crt/memory/stack_allocator.c
$(QUIET)mkdir -p $(@D)
$(QUIET)$(CC) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS)
Expand Down