Skip to content
Merged
1 change: 1 addition & 0 deletions python/tvm/micro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .debugger import GdbRemoteDebugger
from .micro_library import MicroLibrary
from .micro_binary import MicroBinary
from .model_library_format import export_model_library_format, UnsupportedInModelLibraryFormatError
from .session import (
create_local_graph_runtime,
create_local_debug_runtime,
Expand Down
171 changes: 171 additions & 0 deletions python/tvm/micro/model_library_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Defines functions for exporting to Model Library Format."""

import datetime
import json
import os
import re
import tarfile

from ..contrib import utils
from ..relay.backend import graph_runtime_factory
from ..relay import param_dict


class UnsupportedInModelLibraryFormatError(Exception):
"""Raised when export_model_library_format does not support the given Module tree."""


def _populate_codegen_dir(mod, codegen_dir: str):
"""Populate the codegen sub-directory as part of a Model Library Format export.

Parameters
----------
mod : tvm.runtime.Module
Module which should be written to codegen_dir.
codegen_dir : str
Path to the codegen directory on disk.
"""
dso_modules = mod._collect_dso_modules()
dso_module_handles = [m.handle.value for m in dso_modules]
non_dso_modules = mod._collect_from_import_tree(lambda m: m not in dso_modules)
if non_dso_modules:
raise UnsupportedInModelLibraryFormatError(
f"Don't know how to export non-c or non-llvm modules; found: {non_dso_modules!r}"
)

mod_indices = {"lib": 0, "src": 0}
host_codegen_dir = os.path.join(codegen_dir, "host")
for dso_mod in dso_modules:
if dso_mod.type_key == "c":
index = mod_indices["src"]
mod_indices["src"] += 1
parent_dir = os.path.join(host_codegen_dir, "src")
file_name = os.path.join(parent_dir, f"lib{index}.c")
elif dso_mod.type_key == "llvm":
index = mod_indices["lib"]
mod_indices["lib"] += 1
parent_dir = os.path.join(host_codegen_dir, "lib")
file_name = os.path.join(parent_dir, f"lib{index}.o")
else:
assert (
False
), f"do not expect module with type_key={mod.type_key} from _collect_dso_modules"

if not os.path.exists(parent_dir):
os.makedirs(parent_dir)
dso_mod.save(file_name)


def _build_memory_map(graph_json):
"""Build a simpler memory map from graph JSON.

Parameters
----------
graph_json : str
String representation of the graph_json created from tvm.relay.build().

Returns
-------
list :
A list with one entry per storage id describing that memory.
"""
graph = json.loads(graph_json)

seen_storage_ids = set()
memory_map = []
for node_id, storage_id in enumerate(graph["attrs"]["storage_id"][1]):
if storage_id in seen_storage_ids:
continue

seen_storage_ids.add(storage_id)
num_elements = 1
for dim in graph["attrs"]["shape"][1][storage_id]:
num_elements *= dim

dltype = graph["attrs"]["dltype"][1][storage_id]
m = re.match(r"^[a-zA-Z]+([0-9]+)$", dltype)
assert m, f"Exported graph contains unknown dltype {dltype}"

elem_bits = int(m.group(1))

map_entry = {
"storage_id": storage_id,
"size_bytes": (num_elements * elem_bits + 7) // 8,
}
if node_id in graph["arg_nodes"]:
map_entry["input_binding"] = graph["nodes"][node_id]["name"]

memory_map.append(map_entry)

return memory_map


def export_model_library_format(mod: graph_runtime_factory.GraphRuntimeFactoryModule, file_name):
"""Export the build artifact in Model Library Format.

This function creates a .tar archive containing the build artifacts in a standardized
layout. It's intended to allow downstream automation to build TVM artifacts against the C
runtime.

Parameters
----------
mod : tvm.relay.backend.graph_runtime_factory.GraphRuntimeFactoryModule
The return value of tvm.relay.build, which will be exported into Model Library Format.
file_name : str
Path to the .tar archive to generate.
"""
tempdir = utils.tempdir()
metadata = {
"version": 1,
"model_name": mod.libmod_name,
"export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"),
"memory": _build_memory_map(mod.graph_json),
"target": {int(k): str(v) for k, v in mod.target.items()},
"runtimes": ["graph"],
}
with open(tempdir.relpath("metadata.json"), "w") as json_f:
json.dump(metadata, json_f, indent=2, sort_keys=True)

codegen_dir_path = tempdir.relpath("codegen")
os.mkdir(codegen_dir_path)
_populate_codegen_dir(mod.lib, codegen_dir_path)

parameters_dir_path = tempdir.relpath("parameters")
os.mkdir(parameters_dir_path)
param_filename = os.path.join(parameters_dir_path, f"{mod.libmod_name}.params")
with open(param_filename, "wb") as f:
f.write(param_dict.save_param_dict(mod.params))

with open(tempdir.relpath("relay.txt"), "w") as f:
f.write(str(mod.ir_mod))

graph_config_dir_path = tempdir.relpath(os.path.join("runtime-config", "graph"))
os.makedirs(graph_config_dir_path)
with open(os.path.join(graph_config_dir_path, "graph.json"), "w") as f:
f.write(mod.graph_json)

with tarfile.open(file_name, "w") as tar_f:

def reset(tarinfo):
tarinfo.uid = tarinfo.gid = 0
tarinfo.uname = tarinfo.gname = "root"
return tarinfo

tar_f.add(tempdir.temp_dir, arcname=".", filter=reset)
12 changes: 8 additions & 4 deletions python/tvm/relay/backend/graph_runtime_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
# under the License.
"""Graph runtime factory."""
import warnings
from tvm._ffi.base import string_types
from tvm._ffi.registry import get_global_func
from tvm.runtime import ndarray
from ..._ffi.base import string_types
from ..._ffi.registry import get_global_func
from ...runtime import ndarray


class GraphRuntimeFactoryModule:
Expand All @@ -31,6 +31,8 @@ class GraphRuntimeFactoryModule:
The graph to be deployed in json format output by graph compiler.
The graph can contain operator(tvm_op) that points to the name of
PackedFunc in the libmod.
target : tvm.Target
The Target used to build this module.
libmod : tvm.Module
The module of the corresponding function
libmod_name: str
Expand All @@ -39,13 +41,15 @@ class GraphRuntimeFactoryModule:
The parameters of module
"""

def __init__(self, graph_json_str, libmod, libmod_name, params):
def __init__(self, ir_mod, target, graph_json_str, libmod, libmod_name, params):
assert isinstance(graph_json_str, string_types)
fcreate = get_global_func("tvm.graph_runtime_factory.create")
args = []
for k, v in params.items():
args.append(k)
args.append(ndarray.array(v))
self.ir_mod = ir_mod
self.target = target
self.module = fcreate(graph_json_str, libmod, libmod_name, *args)
self.graph_json = graph_json_str
self.lib = libmod
Expand Down
20 changes: 11 additions & 9 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,14 +208,14 @@ def _build_module_no_factory(mod, target=None, target_host=None, params=None, mo
return build(mod, target, target_host, params, mod_name).module


def build(mod, target=None, target_host=None, params=None, mod_name="default"):
def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"):
# fmt: off
# pylint: disable=line-too-long
"""Helper function that builds a Relay function to run on TVM graph runtime.

Parameters
----------
mod : :py:class:`~tvm.IRModule`
ir_mod : :py:class:`~tvm.IRModule`
The IR module to build. Using relay.Function is deprecated.

target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context name) to str/tvm.target.Target, optional
Expand Down Expand Up @@ -251,13 +251,13 @@ def build(mod, target=None, target_host=None, params=None, mod_name="default"):
"""
# pylint: enable=line-too-long
# fmt: on
if not isinstance(mod, (IRModule, _function.Function)):
if not isinstance(ir_mod, (IRModule, _function.Function)):
raise ValueError("Type of input parameter mod must be tvm.IRModule")

if isinstance(mod, _function.Function):
if isinstance(ir_mod, _function.Function):
if params:
mod = bind_params_by_name(mod, params)
mod = IRModule.from_expr(mod)
ir_mod = bind_params_by_name(ir_mod, params)
ir_mod = IRModule.from_expr(ir_mod)
warnings.warn(
"Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter mod (tvm.relay.function.Function)",
Expand All @@ -280,9 +280,11 @@ def build(mod, target=None, target_host=None, params=None, mod_name="default"):

with tophub_context:
bld_mod = BuildModule()
graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
mod = _graph_runtime_factory.GraphRuntimeFactoryModule(graph_json, mod, mod_name, params)
return mod
graph_json, runtime_mod, params = bld_mod.build(ir_mod, target, target_host, params)
runtime_mod = _graph_runtime_factory.GraphRuntimeFactoryModule(
ir_mod, target, graph_json, runtime_mod, mod_name, params
)
return runtime_mod


def optimize(mod, target=None, params=None):
Expand Down
26 changes: 21 additions & 5 deletions python/tvm/runtime/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ def __getitem__(self, name):
raise ValueError("Can only take string as function name")
return self.get_function(name)

def __eq__(self, other):
return self.handle.value == other.handle.value

def __call__(self, *args):
if self._entry:
return self._entry(*args)
Expand Down Expand Up @@ -233,24 +236,37 @@ def evaluator(*args):
except NameError:
raise NameError("time_evaluate is only supported when RPC is enabled")

def _collect_dso_modules(self):
"""Helper function to collect dso modules, then return it."""
def _collect_from_import_tree(self, filter_func):
"""Helper function to collect modules from the tree matching a filter_func, then return it.

Parameters
----------
filter_func : Callable[[Module], bool]
A function which is invoked for each Module discovered in the import tree (including
self).

Returns
-------
list[Module] :
A list of matching Module.
"""
visited, stack, dso_modules = set(), [], []
# append root module
visited.add(self)
stack.append(self)
while stack:
module = stack.pop()
if module._dso_exportable():
if filter_func(module):
dso_modules.append(module)
for m in module.imported_modules:
if m not in visited:
visited.add(m)
stack.append(m)
return dso_modules

def _dso_exportable(self):
return self.type_key == "llvm" or self.type_key == "c"
def _collect_dso_modules(self):
is_dso_exportable = lambda m: (m.type_key == "llvm" or m.type_key == "c")
return self._collect_from_import_tree(is_dso_exportable)

def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=None, **kwargs):
"""Export the module and its imported device code one library.
Expand Down
3 changes: 2 additions & 1 deletion src/runtime/graph/graph_runtime_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_factory.create").set_body([](TVMArgs args
"graph_runtime_factory.create needs at least 3, "
"but it has "
<< args.num_args;
// The argument order is graph_json, module, module_name, params.
// The argument order is graph_json, module, module_name, param0_name, param0_tensor,
// [param1_name, param1_tensor], ...
ICHECK_EQ((args.size() - 3) % 2, 0);
std::unordered_map<std::string, tvm::runtime::NDArray> params;
for (size_t i = 3; i < static_cast<size_t>(args.size()); i += 2) {
Expand Down
Loading