diff --git a/python/tvm/contrib/utils.py b/python/tvm/contrib/utils.py index 6451896c6bd1..68c6b3d5bf6b 100644 --- a/python/tvm/contrib/utils.py +++ b/python/tvm/contrib/utils.py @@ -19,6 +19,7 @@ import contextlib import datetime import os +import pathlib import tempfile import threading import shutil @@ -119,6 +120,18 @@ def remove(self): self.TEMPDIRS.remove(self.temp_dir) self.temp_dir = None + @property + def path(self): + return pathlib.Path(self.temp_dir) + + def __div__(self, other): + if not isinstance(other, (str, pathlib.Path)): + raise TypeError( + "TempDirectory / operator: must supply str or pathlib.Path; got %r" % (other,) + ) + + return self.path / other + def __del__(self): temp_dirs = getattr(self, "TEMPDIRS", None) if temp_dirs is None: diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index a9e07299f6dd..0533898ded35 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -24,6 +24,7 @@ import tvm.tir +from tvm.runtime import Module from tvm.runtime import ndarray from tvm.ir import container from tvm.ir import CallingConv @@ -372,12 +373,32 @@ def build( create_csource_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateCSourceCrtMetadataModule" ) - return create_csource_crt_metadata_module([rt_mod_host], target_host) + to_return = create_csource_crt_metadata_module([rt_mod_host], target_host) - if target_host.kind.name == "llvm": + elif target_host.kind.name == "llvm": create_llvm_crt_metadata_module = tvm._ffi.get_global_func( "runtime.CreateLLVMCrtMetadataModule" ) - return create_llvm_crt_metadata_module([rt_mod_host], target_host) + to_return = create_llvm_crt_metadata_module([rt_mod_host], target_host) + else: + to_return = rt_mod_host + + return OperatorModule.from_module(to_return, ir_module_by_target=target_input_mod, name=name) + + +class OperatorModule(Module): + """Wraps the Module returned by tvm.build() and captures additional outputs of that function.""" + + @classmethod + def from_module(cls, mod, **kwargs): + # NOTE(areusch): It is generally unsafe to continue using `mod` from this point forward. + # If an exception occurs in cls.__init__, handle will be deleted. For this reason, + # set mod.handle to None. + handle = mod.handle + mod.handle = None + return cls(handle, **kwargs) - return rt_mod_host + def __init__(self, handle, ir_module_by_target=None, name=None): + super(OperatorModule, self).__init__(handle) + self.ir_module_by_target = ir_module_by_target + self.name = name diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index 7062b20e0d54..87c067051f82 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -20,12 +20,18 @@ import datetime import json import os +import pathlib import re import tarfile +import typing +from .._ffi import get_global_func from ..contrib import utils +from ..driver import build_module +from ..runtime import ndarray as _nd from ..relay.backend import executor_factory from ..relay import param_dict +from ..tir import expr # This should be kept identical to runtime::symbol::tvm_module_main MAIN_FUNC_NAME_STR = "__tvm_main__" @@ -207,67 +213,207 @@ def _build_function_memory_map(function_metadata): return ret -def export_model_library_format(mod: executor_factory.ExecutorFactoryModule, file_name): - """Export the build artifact in Model Library Format. +def _make_tar(source_dir, tar_file_path): + """Build a tar file from source_dir.""" + with tarfile.open(tar_file_path, "w") as tar_f: - This function creates a .tar archive containing the build artifacts in a standardized - layout. It's intended to allow downstream automation to build TVM artifacts against the C - runtime. + def reset(tarinfo): + tarinfo.uid = tarinfo.gid = 0 + tarinfo.uname = tarinfo.gname = "root" + return tarinfo + + tar_f.add(str(source_dir), arcname=".", filter=reset) + + +_GENERATED_VERSION = 4 + + +def _export_graph_model_library_format( + mod: executor_factory.ExecutorFactoryModule, tempdir: pathlib.Path +): + """Export a tvm.relay.build artifact in Model Library Format. Parameters ---------- mod : tvm.relay.backend.executor_factory.ExecutorFactoryModule The return value of tvm.relay.build, which will be exported into Model Library Format. - file_name : str - Path to the .tar archive to generate. - - Returns - ------- - file_name : str - The path to the generated .tar archive. + tempdir : pathlib.Path + Temporary directory to populate with Model Library Format contents. """ - tempdir = utils.tempdir() is_aot = isinstance(mod, executor_factory.AOTExecutorFactoryModule) runtime = ["aot"] if is_aot else ["graph"] metadata = { - "version": 3, + "version": _GENERATED_VERSION, "model_name": mod.libmod_name, "export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"), "memory": _build_memory_map(mod), "target": {int(k): str(v) for k, v in mod.target.items()}, "runtimes": runtime, + "style": "full-model", } - with open(tempdir.relpath("metadata.json"), "w") as json_f: + with open(tempdir / "metadata.json", "w") as json_f: json.dump(metadata, json_f, indent=2, sort_keys=True) - codegen_dir_path = tempdir.relpath("codegen") - os.mkdir(codegen_dir_path) - _populate_codegen_dir(mod.lib, codegen_dir_path, mod.libmod_name) + codegen_dir = tempdir / "codegen" + codegen_dir.mkdir() + _populate_codegen_dir(mod.lib, codegen_dir, mod.libmod_name) - parameters_dir_path = tempdir.relpath("parameters") - os.mkdir(parameters_dir_path) - param_filename = os.path.join(parameters_dir_path, f"{mod.libmod_name}.params") + parameters_dir = tempdir / "parameters" + parameters_dir.mkdir() + param_filename = parameters_dir / f"{mod.libmod_name}.params" with open(param_filename, "wb") as f: f.write(param_dict.save_param_dict(mod.params)) - with open(tempdir.relpath("relay.txt"), "w") as f: + src_dir = tempdir / "src" + src_dir.mkdir() + with open(src_dir / "relay.txt", "w") as f: f.write(str(mod.ir_mod)) if not is_aot: - graph_config_dir_path = tempdir.relpath(os.path.join("runtime-config", "graph")) - os.makedirs(graph_config_dir_path) - with open(os.path.join(graph_config_dir_path, "graph.json"), "w") as f: + graph_config_dir = tempdir / "runtime-config" / "graph" + graph_config_dir.mkdir(parents=True) + with open(graph_config_dir / "graph.json", "w") as f: f.write(mod.get_executor_config()) - with tarfile.open(file_name, "w") as tar_f: - def reset(tarinfo): - tarinfo.uid = tarinfo.gid = 0 - tarinfo.uname = tarinfo.gname = "root" - return tarinfo +class NonStaticShapeError(Exception): + """Raised when a shape has elements other than IntImm.""" + + +def _shape_to_size(shape, dtype): + bits_per_item = int( + re.match(r"((float)|(int))(?P[0-9]+)", dtype).group("width_bits") + ) + assert bits_per_item is not None, f"don't know how to compute size of type {dtype}" + total_bits = bits_per_item + for s in shape: + total_bits *= s + + return (total_bits + 7) // 8 + + +def _write_tir_and_build_operator_memory_map(src_dir, targets, ir_module_by_target): + def _eval_shape(param_name, buffer_shape): + shape = [] + for x in buffer_shape: + if not isinstance(x, expr.IntImm): + raise NonStaticShapeError( + f"Parameter {param_name} has shape with non-IntImm elements: {buffer_shape}" + ) + shape.append(x.value) + return shape + + memory_map = {} + for target_device_type, target in targets.items(): + ir_mod = ir_module_by_target[target] + printer = get_global_func("tir.ModelLibraryFormatPrinter")(False, None, False) + with open(src_dir / f"tir-{target_device_type}.txt", "w") as f: + f.write(printer["print"](ir_mod)) + + for v in ir_mod.get_global_vars(): + map_entry = [] + for p, b in ir_mod[v.name_hint].buffer_map.items(): + shape = _eval_shape(p.name, b.shape) + buffer_size_bytes = _shape_to_size(shape, str(b.dtype)) + # NOTE: cannot tell what is an input or output at this point. + map_entry.append( + { + "size_bytes": buffer_size_bytes, + "shape": [int(x) for x in b.shape], + "dtype": b.dtype, + "input_binding": printer["get_var_name"](p), + } + ) + memory_map[v.name_hint] = map_entry + + return memory_map + + +def _export_operator_model_library_format(mod: build_module.OperatorModule, tempdir): + """Export the result of tvm.build() in Model Library Format. + + Parameters + ---------- + mod : runtime.Module + The Module returned from tvm.build(). + args : list of Buffer or Tensor or Var, optional + The args supplied to tvm.build(). + file_name : str + Path to the .tar archive to generate. + """ + targets = {} + for target in mod.ir_module_by_target.keys(): + if str(target.kind) not in ("llvm", "c"): + raise UnsupportedInModelLibraryFormatError( + f"Operator has non-DSO-exportable target {target!s}, which is not yet supported in " + "Model Library Format" + ) + + targets[int(_nd.device(str(target)).device_type)] = target + + src_dir = tempdir / "src" + src_dir.mkdir() + memory_map = _write_tir_and_build_operator_memory_map(src_dir, targets, mod.ir_module_by_target) + + metadata = { + "version": _GENERATED_VERSION, + "model_name": mod.name, + "export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"), + "memory": memory_map, + "target": {k: str(v) for k, v in targets.items()}, + "runtimes": [], + "style": "operator", + } + with open(tempdir / "metadata.json", "w") as metadata_f: + json.dump(metadata, metadata_f) + + codegen_dir = tempdir / "codegen" + codegen_dir.mkdir() + _populate_codegen_dir(mod, codegen_dir) + + +ExportableModule = typing.Union[ + build_module.OperatorModule, + executor_factory.AOTExecutorFactoryModule, + executor_factory.GraphExecutorFactoryModule, +] + + +def export_model_library_format(mod: ExportableModule, file_name: typing.Union[str, pathlib.Path]): + """Export the build artifact in Model Library Format. + + This function creates a .tar archive containing the build artifacts in a standardized + layout. It's intended to allow downstream automation to build TVM artifacts against the C + runtime. + + Parameters + ---------- + mod : ExportableModule + The return value of tvm.build or tvm.relay.build. + file_name : str + Path to the .tar archive to generate. + + Returns + ------- + file_name : str + The path to the generated .tar archive. + """ + file_name = pathlib.Path(file_name) + + tempdir = utils.tempdir() - tar_f.add(tempdir.temp_dir, arcname=".", filter=reset) + if isinstance(mod, build_module.OperatorModule): + _export_operator_model_library_format(mod, tempdir.path) + elif isinstance( + mod, + (executor_factory.AOTExecutorFactoryModule, executor_factory.GraphExecutorFactoryModule), + ): + _export_graph_model_library_format(mod, tempdir.path) + else: + raise NotImplementedError(f"Don't know how to export module of type {mod.__class__!r}") + + _make_tar(tempdir.path, file_name) return file_name diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index ed722643ff70..aa826aee57a1 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -40,7 +40,23 @@ from .backend.vm import VMExecutor -def _update_target(target): +def build_target_by_device_type_map(target): + """Build a map from DLDevice device_type to a Target used with that device. + + At runtime, TVM assigns target code to DLDevices by determining a device_type for each Target. + This function handles this process at compile time and, as a side effect, validates that exactly + one target maps to one device_type. + + Parameters + ---------- + target : Target or str or dict + If a Target or str: assumes that exactly one device type is present in the model. + If a dict: keys are tvm.ndarray.device, values are the targets used for each device. + + Returns + ------- + + """ target = target if target else Target.current() if target is None: raise ValueError("Target is not set in env or passed as argument.") @@ -132,7 +148,7 @@ def build( params : dict The parameters of the final graph. """ - target = _update_target(target) + target = build_target_by_device_type_map(target) target, target_host = Target.check_and_update_host_consist( target, target_host, target_is_dict_key=False ) @@ -187,7 +203,7 @@ def optimize(self, mod, target=None, params=None): params : dict The parameters of the final graph. """ - target = _update_target(target) + target = build_target_by_device_type_map(target) # Setup the params. if params: @@ -316,7 +332,7 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default" "instead of deprecated parameter mod (tvm.relay.function.Function)", DeprecationWarning, ) - target = _update_target(target) + target = build_target_by_device_type_map(target) if isinstance(target_host, (str, Target)): target_host = Target(target_host) elif target_host: @@ -395,7 +411,7 @@ def optimize(mod, target=None, params=None): DeprecationWarning, ) - target = _update_target(target) + target = build_target_by_device_type_map(target) # If current dispatch context is fallback context (the default root context), # then load pre-tuned parameters from TopHub diff --git a/src/printer/model_library_format_printer.cc b/src/printer/model_library_format_printer.cc new file mode 100644 index 000000000000..17ba84e68df4 --- /dev/null +++ b/src/printer/model_library_format_printer.cc @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +#include + +#include "text_printer.h" + +namespace tvm { +namespace printer { + +class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode { + public: + ModelLibraryFormatPrinter(bool show_meta_data, + const runtime::TypedPackedFunc& annotate, + bool show_warning) + : text_printer_{show_meta_data, annotate, show_warning} {} + + const char* type_key() const override { return "model_library_format_printer"; } + + std::string Print(const ObjectRef& node) { + Doc doc; + doc << text_printer_.PrintFinal(node); + return doc.str(); + } + + TVMRetValue GetVarName(tir::Var var) { + TVMRetValue rv; + std::string var_name; + if (text_printer_.GetVarName(var, &var_name)) { + rv = var_name; + } + + return rv; + } + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) override { + if (name == "print") { + return TypedPackedFunc( + [sptr_to_self, this](ObjectRef node) { return Print(node); }); + } else if (name == "get_var_name") { + return TypedPackedFunc( + [sptr_to_self, this](tir::Var var) { return GetVarName(var); }); + } else { + return PackedFunc(); + } + } + + private: + TextPrinter text_printer_; +}; + +TVM_REGISTER_GLOBAL("tir.ModelLibraryFormatPrinter") + .set_body_typed([](bool show_meta_data, + const runtime::TypedPackedFunc& annotate, + bool show_warning) { + return ObjectRef( + make_object(show_meta_data, annotate, show_warning)); + }); + +} // namespace printer +} // namespace tvm diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 7a529cc0b914..0332a2d539d2 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -34,6 +34,7 @@ #include #include #include +#include #include #include @@ -256,6 +257,13 @@ class TIRTextPrinter : public StmtFunctor, /*! \brief Print the node */ Doc Print(const ObjectRef& node); + /*! \brief Place into `s` the name used in the preceding Print call for `v`. + * \param v Var instance to check. Must point to a VarNode visited by Print. + * \param s String to receive the name. + * \return true when a name re-mapping was found. + */ + bool GetVarName(::tvm::tir::Var v, std::string* s); + private: /*! \brief whether show meta data */ bool show_meta_; @@ -394,6 +402,8 @@ class TextPrinter { /*! \brief TIR Text Printer */ tir::TIRTextPrinter tir_text_printer_; + bool GetVarName(::tvm::tir::Var v, std::string* s) { return tir_text_printer_.GetVarName(v, s); } + Doc PrintFinal(const ObjectRef& node) { Doc doc; if (node->IsInstance()) { diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 04c5ea1cdf99..0fefb0515e49 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -734,5 +734,15 @@ Doc TIRTextPrinter::PrintBody(const Stmt& body, bool indent) { return doc; } +bool TIRTextPrinter::GetVarName(Var v, std::string* s) { + auto it = memo_var_.find(v); + if (it == memo_var_.end()) { + return false; + } + + *s = it->second.str(); + return true; +} + } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index 2922a3adf48b..246c0336a001 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -32,6 +32,62 @@ from tvm.contrib import utils +@tvm.testing.requires_micro +def test_export_operator_model_library_format(): + import tvm.micro as micro + + target = tvm.target.target.micro("host") + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + A = tvm.te.placeholder((2,), dtype="int8") + B = tvm.te.placeholder((1,), dtype="int8") + C = tvm.te.compute(A.shape, lambda i: A[i] + B[0], name="C") + sched = tvm.te.create_schedule(C.op) + mod = tvm.build(sched, [A, B, C], tvm.target.Target(target, target), name="add") + + temp_dir = utils.tempdir() + mlf_tar_path = temp_dir.relpath("lib.tar") + micro.export_model_library_format(mod, mlf_tar_path) + + tf = tarfile.open(mlf_tar_path) + + extract_dir = temp_dir.relpath("extract") + os.mkdir(extract_dir) + tf.extractall(extract_dir) + + with open(os.path.join(extract_dir, "metadata.json")) as json_f: + metadata = json.load(json_f) + assert metadata["version"] == 4 + assert metadata["model_name"] == "add" + export_datetime = datetime.datetime.strptime( + metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" + ) + assert (datetime.datetime.now() - export_datetime) < datetime.timedelta(seconds=60 * 5) + assert metadata["target"] == {"1": str(target)} + + assert metadata["memory"]["add"][0]["dtype"] == "int8" + assert metadata["memory"]["add"][0]["shape"] == [2] + assert metadata["memory"]["add"][0]["size_bytes"] == 2 + + assert metadata["memory"]["add"][1]["dtype"] == "int8" + assert metadata["memory"]["add"][1]["shape"] == [1] + assert metadata["memory"]["add"][1]["size_bytes"] == 1 + + assert metadata["memory"]["add"][2]["dtype"] == "int8" + assert metadata["memory"]["add"][2]["shape"] == [2] + assert metadata["memory"]["add"][2]["size_bytes"] == 2 + + assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "lib0.c")) + assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", "lib1.c")) + + assert ( + len(mod.ir_module_by_target) == 1 + ), f"expect 1 ir_model_by_target: {ir_module_by_target!r}" + for target, ir_mod in mod.ir_module_by_target.items(): + assert int(tvm.runtime.ndarray.device(str(target)).device_type) == 1 + with open(os.path.join(extract_dir, "src", "tir-1.txt")) as tir_f: + assert tir_f.read() == str(ir_mod) + + def validate_graph_json(extract_dir, factory): with open(os.path.join(extract_dir, "runtime-config", "graph", "graph.json")) as graph_f: graph_json = graph_f.read() @@ -85,7 +141,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 3 + assert metadata["version"] == 4 assert metadata["model_name"] == "add" export_datetime = datetime.datetime.strptime( metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" @@ -121,7 +177,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ if executor == "graph": validate_graph_json(extract_dir, factory) - with open(os.path.join(extract_dir, "relay.txt")) as relay_f: + with open(os.path.join(extract_dir, "src", "relay.txt")) as relay_f: assert relay_f.read() == str(relay_mod) with open(os.path.join(extract_dir, "parameters", "add.params"), "rb") as params_f: @@ -165,7 +221,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 3 + assert metadata["version"] == 4 assert metadata["model_name"] == "add" export_datetime = datetime.datetime.strptime( metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" @@ -198,7 +254,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[ validate_graph_json(extract_dir, factory) - with open(os.path.join(extract_dir, "relay.txt")) as relay_f: + with open(os.path.join(extract_dir, "src", "relay.txt")) as relay_f: assert relay_f.read() == str(relay_mod) with open(os.path.join(extract_dir, "parameters", "add.params"), "rb") as params_f: @@ -244,7 +300,7 @@ def @main(%p0: Tensor[(1, 56, 56, 128), int16], %p1: Tensor[(3, 3, 128, 1), int1 with open(os.path.join(extract_dir, "metadata.json")) as json_f: metadata = json.load(json_f) - assert metadata["version"] == 3 + assert metadata["version"] == 4 assert metadata["model_name"] == "qnn_conv2d" export_datetime = datetime.datetime.strptime( metadata["export_datetime"], "%Y-%m-%d %H:%M:%SZ" @@ -269,7 +325,7 @@ def @main(%p0: Tensor[(1, 56, 56, 128), int16], %p1: Tensor[(3, 3, 128, 1), int1 @tvm.testing.requires_micro -def test_export_model(): +def test_export_non_dso_exportable(): module = tvm.support.FrontendTestModule() factory = executor_factory.GraphExecutorFactoryModule( None, tvm.target.target.micro("host"), '"graph_json"', module, "test_module", {}, {}