From d37518e8eff9a6a18f6ce3a98212f28a17863260 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 3 Apr 2024 08:20:35 -0500 Subject: [PATCH 1/3] [Runtime] Allow inspection of function names from a compiled .so Prior to this commit, the `LLVMModuleNode` provided a `"get_func_names"` function that would return the names of functions available within the result of `tvm.build`. However, this utility was not preserved across a round-trip through `mod.export_library` and `tvm.runtime.load_module`. This commit adds a similar `"__get_func_names"` function to a `DSOLibrary`, which returns the symbols available for use in the library. This is exposed in the `Module.keys()` method, mimicking the interface that a Python user would expect. In addition, this is used to improve the error message shown when a library does not contain the requested function. The `difflib` module (from Python's stdlib) is used to find functions that are contained in the module, and have similar names to the one requested. --- python/tvm/runtime/module.py | 60 +++++++++++-- src/runtime/dso_library.cc | 78 +++++++++++++++- src/runtime/library_module.cc | 8 ++ src/runtime/library_module.h | 7 ++ src/target/llvm/llvm_module.cc | 2 +- .../codegen/test_target_codegen_llvm.py | 90 ++++++++++++++++++- 6 files changed, 235 insertions(+), 10 deletions(-) diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 2c3eff700009..f8cd34a3249a 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -17,10 +17,14 @@ # pylint: disable=invalid-name, unused-import, import-outside-toplevel, inconsistent-return-statements """Runtime Module namespace.""" -import os + import ctypes +import difflib +import os import struct -from typing import Sequence + +from typing import Sequence, Iterable, Iterator + import numpy as np from tvm._ffi.base import _LIB, check_call, c_str, string_types, _RUNTIME_ONLY @@ -173,7 +177,19 @@ def get_function(self, name, query_imports=False): ) ) if not ret_handle.value: - raise AttributeError(f"Module has no function '{name}'") + nearby_names = difflib.get_close_matches(name, self.keys()) + if nearby_names: + message = ( + f"Module has no function '{name}'. " + f"The module does contain functions with similar names: " + f"{nearby_names}." + ) + else: + message = ( + f"Module has no function '{name}'. " + f"The module does not contain any function with a similar name." + ) + raise KeyError(message) return PackedFunc(ret_handle, False) def import_module(self, module): @@ -186,11 +202,45 @@ def import_module(self, module): """ check_call(_LIB.TVMModImport(self.handle, module.handle)) - def __getitem__(self, name): + def __getitem__(self, name: str) -> PackedFunc: + """Return the PackedFunc associated with the gven name + + Parameters + ---------- + name: str + The name of the function to be returned + + Returns + ------- + PackedFunc + + """ + if not isinstance(name, string_types): - raise ValueError("Can only take string as function name") + raise TypeError(f"Module.__getitem__ expects a string, but received {type(name)}") return self.get_function(name) + def __contains__(self, key: str) -> bool: + return key in self.keys() + + def keys(self) -> Sequence[str]: + """Return a list of functions in the module + + Returns + ------- + Sequence[str] + The functions in the module + """ + return self["__get_func_names"]() + + def values(self) -> Iterator[PackedFunc]: + for key in self.keys(): + yield self[key] + + def items(self) -> Iterator[PackedFunc]: + for key in self.keys(): + yield key, self[key] + def __eq__(self, other): return self.handle.value == other.handle.value diff --git a/src/runtime/dso_library.cc b/src/runtime/dso_library.cc index e4f4937a8a9f..095afaa1c725 100644 --- a/src/runtime/dso_library.cc +++ b/src/runtime/dso_library.cc @@ -32,6 +32,8 @@ #include #else #include +#include +#include #endif #if defined(__hexagon__) @@ -51,6 +53,7 @@ namespace runtime { class DSOLibrary final : public Library { public: ~DSOLibrary(); + /*! * \brief Initialize by loading and storing * a handle to the underlying shared library. @@ -58,6 +61,7 @@ class DSOLibrary final : public Library { * shared library over which to initialize. */ void Init(const std::string& name); + /*! * \brief Returns the symbol address within * the shared library for a given symbol name. @@ -66,18 +70,40 @@ class DSOLibrary final : public Library { */ void* GetSymbol(const char* name) final; + /*! \brief List symbols available within the module + * + * \param callback The callback to be executed for each symbol in + * the library. + */ + void ListSymbols(std::function callback) final; + private: /*! \brief Private implementation of symbol lookup. - * Implementation is operating system dependent. - * \param The name of the symbol. + * + * Implementation is operating system dependent. + * + * \param name The name of the symbol. + * * \return The symbol. */ void* GetSymbol_(const char* name); + + /*! \brief Private implementation of symbol lookup. + * + * Implementation is operating system dependent. + * + * \param callback The callback for each symbol located. + * + * \return The symbol. + */ + void ListSymbols_(std::function callback); + /*! \brief Implementation of shared library load. * Implementation is operating system dependent. * \param The name/path of the shared library. */ void Load(const std::string& name); + /*! \brief Implementation of shared library unload. * Implementation is operating system dependent. */ @@ -100,12 +126,16 @@ void DSOLibrary::Init(const std::string& name) { Load(name); } void* DSOLibrary::GetSymbol(const char* name) { return GetSymbol_(name); } +void DSOLibrary::ListSymbols(std::function callback) { ListSymbols_(callback); } + #if defined(_WIN32) void* DSOLibrary::GetSymbol_(const char* name) { return reinterpret_cast(GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) } +void DSOLibrary::ListSymbols_(std::function callback) {} + void DSOLibrary::Load(const std::string& name) { // use wstring version that is needed by LLVM. std::wstring wname(name.begin(), name.end()); @@ -136,6 +166,50 @@ void DSOLibrary::Load(const std::string& name) { void* DSOLibrary::GetSymbol_(const char* name) { return dlsym(lib_handle_, name); } +void DSOLibrary::ListSymbols_(std::function callback) { + // Adapted from https://stackoverflow.com/a/62205128 + + struct link_map* map = nullptr; + dlinfo(lib_handle_, RTLD_DI_LINKMAP, &map); + + Elf64_Sym* symbol_table = nullptr; + char* string_table = nullptr; + int entry_size = 0; + for (auto section = map->l_ld; section->d_tag != DT_NULL; ++section) { + if (section->d_tag == DT_SYMTAB) { + symbol_table = (Elf64_Sym*)section->d_un.d_ptr; + } else if (section->d_tag == DT_STRTAB) { + string_table = (char*)section->d_un.d_ptr; + } else if (section->d_tag == DT_SYMENT) { + entry_size = section->d_un.d_val; + } + } + + CHECK(symbol_table) << "RuntimeError: " + << "Malformed ELF binary '" << map->l_name + << "', no symbol table (DT_SYMTAB) found"; + CHECK(string_table) << "RuntimeError: " + << "Malformed ELF binary '" << map->l_name + << "', no string table (DT_STRTAB) found"; + + int symbol_table_size = string_table - (char*)symbol_table; + int num_symbols = symbol_table_size / entry_size; + for (int i = 0; i < num_symbols; i++) { + Elf64_Sym* symbol = &symbol_table[i]; + + if ( + // If the symbol is a function + ELF64_ST_TYPE(symbol->st_info) == STT_FUNC && + // defined with global linkage + ELF64_ST_BIND(symbol->st_info) == STB_GLOBAL && + // and visible to external modules. + ELF64_ST_VISIBILITY(symbol->st_other) == STV_DEFAULT) { + const char* symbol_name = &string_table[symbol->st_name]; + callback(symbol_name); + } + } +} + void DSOLibrary::Unload() { dlclose(lib_handle_); lib_handle_ = nullptr; diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index 7b39bcd8da02..4bf9431fbfca 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -48,6 +48,14 @@ class LibraryModuleNode final : public ModuleNode { }; PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { + if (name == "__get_func_names") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + Array names; + lib_->ListSymbols([&names](const char* symbol) { names.push_back(symbol); }); + *rv = std::move(names); + }); + } + TVMBackendPackedCFunc faddr; if (name == runtime::symbol::tvm_module_main) { const char* entry_name = diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h index d4d32abe2110..8648cd2ba1c8 100644 --- a/src/runtime/library_module.h +++ b/src/runtime/library_module.h @@ -64,6 +64,13 @@ class Library : public Object { virtual void* GetSymbol(const char* name) = 0; // NOTE: we do not explicitly create an type index and type_key here for libary. // This is because we do not need dynamic type downcasting. + + /*! \brief List symbols available within the module + * + * \param callback The callback to be executed for each symbol in + * the library. + */ + virtual void ListSymbols(std::function callback) {} }; /*! diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index baa68feedfa2..56705bd475e6 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -171,7 +171,7 @@ PackedFunc LLVMModuleNode::GetFunction(const String& name, const ObjectPtrfunction_names_; }); } else if (name == "get_symbol") { diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index f1316ae3cee0..a54776d6ccb3 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -14,14 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + import collections import ctypes import json import math -import numpy as np -import pytest +import pathlib import re import sys +import tempfile + +import numpy as np +import pytest + import tvm import tvm.testing @@ -1109,5 +1114,86 @@ def func(): built = tvm.build(func, target="llvm") +save_and_reload = tvm.testing.parameter( + by_dict={ + "llvm_module": False, + "library_module": True, + } +) + + +@tvm.testing.requires_llvm +def test_inspect_module_contents(save_and_reload: bool): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(1, dtype="float32")): + mod.subroutine(A.data) + + @T.prim_func(private=True) + def subroutine(A_data: T.handle("float32")): + A = T.decl_buffer(1, dtype="float32", data=A_data) + A[0] = 42.0 + + @T.prim_func + def main_2(A: T.Buffer(1, dtype="float32")): + mod.subroutine(A.data) + + mod = tvm.build(mod, target="llvm") + + if save_and_reload: + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir = pathlib.Path(temp_dir) + model_path = temp_dir.joinpath("lib.so") + mod.export_library(model_path) + mod = tvm.runtime.load_module(model_path) + + assert "main" in mod + assert "main_2" in mod + assert "subroutine" not in mod + + assert set(mod.keys()) == set(["main", "main_2"]) + + +@tvm.testing.requires_llvm +def test_suggest_nearby_name(save_and_reload: bool): + @I.ir_module + class mod: + @T.prim_func + def main(): + pass + + @T.prim_func + def long_tedious_name_that_would_be_easy_to_misspell(): + pass + + mod = tvm.build(mod, target="llvm") + + if save_and_reload: + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir = pathlib.Path(temp_dir) + model_path = temp_dir.joinpath("lib.so") + mod.export_library(model_path) + mod = tvm.runtime.load_module(model_path) + + mod["main"] + mod["long_tedious_name_that_would_be_easy_to_misspell"] + + with pytest.raises(KeyError, match=r"similar names: \['main'\]"): + mod["mian"] + + with pytest.raises( + KeyError, + match=r"similar names: \['long_tedious_name_that_would_be_easy_to_misspell'\]", + ): + mod["long_tedious_name_that_would_be_easy_to_mispell"] + + with pytest.raises( + KeyError, + match=r"does not contain any function with a similar name", + ): + mod["unrelated_name_that_is_different_from_any_valid_name"] + + if __name__ == "__main__": tvm.testing.main() From 516e33022de3f73321a51964b00e5f0d1ee8b593 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 3 Apr 2024 13:48:17 -0500 Subject: [PATCH 2/3] Handle return of `T.StringImm` from PrimFunc --- include/tvm/runtime/c_backend_api.h | 12 ++++++ include/tvm/runtime/packed_func.h | 3 +- python/tvm/tir/op.py | 3 ++ src/runtime/c_runtime_api.cc | 6 +++ src/tir/transforms/make_packed_api.cc | 4 ++ .../codegen/test_target_codegen_llvm.py | 38 ++++++++++++++----- 6 files changed, 55 insertions(+), 11 deletions(-) diff --git a/include/tvm/runtime/c_backend_api.h b/include/tvm/runtime/c_backend_api.h index 8fde5948f993..82d420bd8b47 100644 --- a/include/tvm/runtime/c_backend_api.h +++ b/include/tvm/runtime/c_backend_api.h @@ -166,6 +166,18 @@ TVM_DLL int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv); */ TVM_DLL int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes); +/*! \brief Generate a pointer usable as kTVMStr in a TVMRetValue + * + * While TVMArgValue uses `const char*` to represent string arguments, + * TVMRetValue represents string return values as a heap-allocated C++ + * `std::string` objects. + * + * \param c_str A null-terminated C-style string + * + * \return A pointer to the heap-allocated C++ `std::string`. + */ +void* TVMBackendStringRetValue(const char* c_str); + #ifdef __cplusplus } // TVM_EXTERN_C #endif diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 7266f8c4a50a..e27e9d59fd98 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -974,7 +974,8 @@ class TVMRetValue : public TVMPODValue_ { */ static TVMRetValue MoveFromCHost(TVMValue value, int type_code) { // Can move POD and everything under the object system. - ICHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle); + ICHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle || + type_code == kTVMStr); TVMRetValue ret; ret.value_ = value; ret.type_code_ = type_code; diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 8816880e7b52..bce016d235dc 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -1639,6 +1639,9 @@ def ret(val): """ val = convert(val) + if isinstance(val, tvm.runtime.String): + val = StringImm(val) + return call_intrin(val.dtype, "tir.ret", val) diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index ea22b89dd771..11c178b110f6 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -555,6 +555,12 @@ int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) { return 0; } +void* TVMBackendStringRetValue(const char* c_str) { + // The `std::string` allocated here is immediately stored in a + // `TVMRetValue`, which takes ownership of the object. + return new std::string(c_str); +} + int TVMFuncFree(TVMFunctionHandle func) { return TVMObjectFree(func); } int TVMByteArrayFree(TVMByteArray* arr) { diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index bf1f3a9e7fd2..cc7c14fde5e1 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -90,6 +90,10 @@ class ReturnRewriter : public StmtMutator { } else if (dtype.is_void()) { info.tcode = kTVMNullptr; info.expr = val; + } else if (val->IsInstance()) { + info.tcode = kTVMStr; + info.expr = Call(DataType::Handle(), builtin::call_pure_extern(), + {StringImm("TVMBackendStringRetValue"), val}); } else { LOG(FATAL) << "data type " << dtype << " not supported yet"; } diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index a54776d6ccb3..a9694030e5d7 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -1122,6 +1122,32 @@ def func(): ) +def _save_and_reload(mod): + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir = pathlib.Path(temp_dir) + model_path = temp_dir.joinpath("lib.so") + mod.export_library(model_path) + mod = tvm.runtime.load_module(model_path) + return mod + + +@tvm.testing.requires_llvm +def test_return_string(save_and_reload: bool): + @I.ir_module + class mod: + @T.prim_func + def main(): + return "my string here" + + mod = tvm.build(mod) + + if save_and_reload: + mod = _save_and_reload(mod) + + res = mod() + assert res == "my string here" + + @tvm.testing.requires_llvm def test_inspect_module_contents(save_and_reload: bool): @I.ir_module @@ -1142,11 +1168,7 @@ def main_2(A: T.Buffer(1, dtype="float32")): mod = tvm.build(mod, target="llvm") if save_and_reload: - with tempfile.TemporaryDirectory() as temp_dir: - temp_dir = pathlib.Path(temp_dir) - model_path = temp_dir.joinpath("lib.so") - mod.export_library(model_path) - mod = tvm.runtime.load_module(model_path) + mod = _save_and_reload(mod) assert "main" in mod assert "main_2" in mod @@ -1170,11 +1192,7 @@ def long_tedious_name_that_would_be_easy_to_misspell(): mod = tvm.build(mod, target="llvm") if save_and_reload: - with tempfile.TemporaryDirectory() as temp_dir: - temp_dir = pathlib.Path(temp_dir) - model_path = temp_dir.joinpath("lib.so") - mod.export_library(model_path) - mod = tvm.runtime.load_module(model_path) + mod = _save_and_reload(mod) mod["main"] mod["long_tedious_name_that_would_be_easy_to_misspell"] From 41a06622eb40b483145179cbde7f2de8a81d935d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 3 Apr 2024 13:20:06 -0500 Subject: [PATCH 3/3] Re-implement in terms of a lowering pass --- include/tvm/runtime/module.h | 2 + include/tvm/tir/transform.h | 7 + python/tvm/runtime/module.py | 76 +++++-- python/tvm/tir/transform/transform.py | 16 ++ src/driver/driver_api.cc | 18 +- src/runtime/dso_library.cc | 78 +------ src/runtime/library_module.cc | 8 - src/runtime/library_module.h | 7 - src/target/llvm/llvm_module.cc | 2 +- src/target/source/codegen_c.cc | 23 +- src/target/source/codegen_c_host.cc | 69 +++++- src/target/source/codegen_c_host.h | 18 ++ .../generate_function_signature_metadata.cc | 98 +++++++++ ...rm_generate_function_signature_metadata.py | 199 ++++++++++++++++++ 14 files changed, 511 insertions(+), 110 deletions(-) create mode 100644 src/tir/transforms/generate_function_signature_metadata.cc create mode 100644 tests/python/tir-transform/test_tir_transform_generate_function_signature_metadata.py diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 308c936d624a..a3b85d517e1f 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -298,6 +298,8 @@ constexpr const char* tvm_module_main = "__tvm_main__"; constexpr const char* tvm_param_prefix = "__tvm_param__"; /*! \brief A PackedFunc that looks up linked parameters by storage_id. */ constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param"; +/*! \brief Placeholder for the module's entry function. */ +constexpr const char* tvm_get_tir_function_metadata = "__tvm_get_tir_function_metadata__"; /*! \brief Model entrypoint generated as an interface to the AOT function outside of TIR */ constexpr const char* tvm_entrypoint_suffix = "run"; } // namespace symbol diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 98edbeaceb26..2ef6f7c8c9d2 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -263,6 +263,13 @@ TVM_DLL Pass LowerCustomDatatypes(); */ TVM_DLL Pass DecorateDeviceScope(); +/*! + * \brief Generate module metadata describing function signatures + * + * \return The pass. + */ +TVM_DLL Pass GenerateFunctionSignatureMetadata(); + /*! * \brief Annotate locations that should be run on the device * diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index f8cd34a3249a..b432022e8f45 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -20,10 +20,11 @@ import ctypes import difflib +import json import os import struct -from typing import Sequence, Iterable, Iterator +from typing import Sequence, Iterable, Iterator, Optional import numpy as np @@ -104,7 +105,20 @@ class ModulePropertyMask(object): class Module(object): """Runtime Module.""" - __slots__ = ["handle", "_entry", "entry_name"] + __slots__ = ["handle", "_entry", "entry_name", "_cached_metadata"] + + _GET_TIR_FUNCTION_METADATA = "__tvm_get_tir_function_metadata__" + """The function to retrieve TIR metadata, if any + + This is for internal use only. If present in the module, this + should be a function that returns a JSON string describing the + functions provided by the module, along with the signatures of + those functions. + + This parameter corresponds to the C++ value + `tvm::runtime::symbol::tvm_get_tir_function_metadata`, located in + `#include `. + """ def __init__(self, handle): self.handle = handle @@ -167,16 +181,11 @@ def get_function(self, name, query_imports=False): Returns ------- - f : tvm.runtime.PackedFunc + func : tvm.runtime.PackedFunc The result function. """ - ret_handle = PackedFuncHandle() - check_call( - _LIB.TVMModGetFunction( - self.handle, c_str(name), ctypes.c_int(query_imports), ctypes.byref(ret_handle) - ) - ) - if not ret_handle.value: + func = self._get_function(name, query_imports=query_imports) + if func is None: nearby_names = difflib.get_close_matches(name, self.keys()) if nearby_names: message = ( @@ -190,7 +199,36 @@ def get_function(self, name, query_imports=False): f"The module does not contain any function with a similar name." ) raise KeyError(message) - return PackedFunc(ret_handle, False) + + return func + + def _get_function(self, name, query_imports=False) -> Optional[PackedFunc]: + """Get function from the module. + + Parameters + ---------- + name : str + The name of the function + + query_imports : bool + Whether also query modules imported by this module. + + Returns + ------- + func : Optional[tvm.runtime.PackedFunc] + The result function, or None if it cannot be found + """ + ret_handle = PackedFuncHandle() + check_call( + _LIB.TVMModGetFunction( + self.handle, c_str(name), ctypes.c_int(query_imports), ctypes.byref(ret_handle) + ) + ) + # pylint: disable=using-constant-test + if ret_handle.value: + return PackedFunc(ret_handle, False) + else: + return None def import_module(self, module): """Add module to the import list of current one. @@ -223,6 +261,19 @@ def __getitem__(self, name: str) -> PackedFunc: def __contains__(self, key: str) -> bool: return key in self.keys() + @property + def _metadata(self): + if hasattr(self, "_cached_metadata"): + # pylint: disable=access-member-before-definition + return self._cached_metadata + + metadata_func = self._get_function(Module._GET_TIR_FUNCTION_METADATA) + if metadata_func is None: + raise RuntimeError("Cannot find function metadata in runtime.Module") + + self._cached_metadata = json.loads(metadata_func()) + return self._cached_metadata + def keys(self) -> Sequence[str]: """Return a list of functions in the module @@ -231,7 +282,8 @@ def keys(self) -> Sequence[str]: Sequence[str] The functions in the module """ - return self["__get_func_names"]() + for function in self._metadata["functions"]: + yield function def values(self) -> Iterator[PackedFunc]: for key in self.keys(): diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index c2022b918643..6ffd6974f30c 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -488,6 +488,22 @@ def MakeUnpackedAPI(): return _ffi_api.MakeUnpackedAPI() # type: ignore +def GenerateFunctionSignatureMetadata(): + """Generate metadata describing the function signatures + + Generate a metadata function that returns a JSON-formatted string, + describing the functions available within the IRModule and their + signatures. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + + """ + return _ffi_api.GenerateFunctionSignatureMetadata() # type: ignore + + def AnnotateDeviceRegions(): """Annotate locations that should be run on the device diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 7ea5032fa0cc..4fcd86d85417 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -569,8 +569,22 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) Array mixed_pass_list; - // FPComputeLegalize uses the target attrs added by BindTarget, so it must come first + // AnnotateEntryFunc inspects user-defined functions to provide a + // default function to call. Therefore, it should appear before any + // passes that would generate additional PrimFuncs. + mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc()); + + // GenerateFunctionSignatureMetadata produces a function with + // `tir.is_host_func`, relying on `BindTarget` to generate the + // correct target annotation. Therefore, it must come before + // BindTarget. + mixed_pass_list.push_back(tir::transform::GenerateFunctionSignatureMetadata()); + + // Many later passes, such as FP8ComputeLegalize, use the target + // attrs added by BindTarget. Therefore, BindTarget should occur as + // early as possible. mixed_pass_list.push_back(tir::transform::BindTarget(target)); + mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize()); // VerifyVTCMLimit must occur before LowerVtcmAlloc @@ -580,8 +594,6 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::VerifyMemory()); - mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc()); - bool detect_global_barrier = pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value(); if (detect_global_barrier) { diff --git a/src/runtime/dso_library.cc b/src/runtime/dso_library.cc index 095afaa1c725..e4f4937a8a9f 100644 --- a/src/runtime/dso_library.cc +++ b/src/runtime/dso_library.cc @@ -32,8 +32,6 @@ #include #else #include -#include -#include #endif #if defined(__hexagon__) @@ -53,7 +51,6 @@ namespace runtime { class DSOLibrary final : public Library { public: ~DSOLibrary(); - /*! * \brief Initialize by loading and storing * a handle to the underlying shared library. @@ -61,7 +58,6 @@ class DSOLibrary final : public Library { * shared library over which to initialize. */ void Init(const std::string& name); - /*! * \brief Returns the symbol address within * the shared library for a given symbol name. @@ -70,40 +66,18 @@ class DSOLibrary final : public Library { */ void* GetSymbol(const char* name) final; - /*! \brief List symbols available within the module - * - * \param callback The callback to be executed for each symbol in - * the library. - */ - void ListSymbols(std::function callback) final; - private: /*! \brief Private implementation of symbol lookup. - * - * Implementation is operating system dependent. - * - * \param name The name of the symbol. - * + * Implementation is operating system dependent. + * \param The name of the symbol. * \return The symbol. */ void* GetSymbol_(const char* name); - - /*! \brief Private implementation of symbol lookup. - * - * Implementation is operating system dependent. - * - * \param callback The callback for each symbol located. - * - * \return The symbol. - */ - void ListSymbols_(std::function callback); - /*! \brief Implementation of shared library load. * Implementation is operating system dependent. * \param The name/path of the shared library. */ void Load(const std::string& name); - /*! \brief Implementation of shared library unload. * Implementation is operating system dependent. */ @@ -126,16 +100,12 @@ void DSOLibrary::Init(const std::string& name) { Load(name); } void* DSOLibrary::GetSymbol(const char* name) { return GetSymbol_(name); } -void DSOLibrary::ListSymbols(std::function callback) { ListSymbols_(callback); } - #if defined(_WIN32) void* DSOLibrary::GetSymbol_(const char* name) { return reinterpret_cast(GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) } -void DSOLibrary::ListSymbols_(std::function callback) {} - void DSOLibrary::Load(const std::string& name) { // use wstring version that is needed by LLVM. std::wstring wname(name.begin(), name.end()); @@ -166,50 +136,6 @@ void DSOLibrary::Load(const std::string& name) { void* DSOLibrary::GetSymbol_(const char* name) { return dlsym(lib_handle_, name); } -void DSOLibrary::ListSymbols_(std::function callback) { - // Adapted from https://stackoverflow.com/a/62205128 - - struct link_map* map = nullptr; - dlinfo(lib_handle_, RTLD_DI_LINKMAP, &map); - - Elf64_Sym* symbol_table = nullptr; - char* string_table = nullptr; - int entry_size = 0; - for (auto section = map->l_ld; section->d_tag != DT_NULL; ++section) { - if (section->d_tag == DT_SYMTAB) { - symbol_table = (Elf64_Sym*)section->d_un.d_ptr; - } else if (section->d_tag == DT_STRTAB) { - string_table = (char*)section->d_un.d_ptr; - } else if (section->d_tag == DT_SYMENT) { - entry_size = section->d_un.d_val; - } - } - - CHECK(symbol_table) << "RuntimeError: " - << "Malformed ELF binary '" << map->l_name - << "', no symbol table (DT_SYMTAB) found"; - CHECK(string_table) << "RuntimeError: " - << "Malformed ELF binary '" << map->l_name - << "', no string table (DT_STRTAB) found"; - - int symbol_table_size = string_table - (char*)symbol_table; - int num_symbols = symbol_table_size / entry_size; - for (int i = 0; i < num_symbols; i++) { - Elf64_Sym* symbol = &symbol_table[i]; - - if ( - // If the symbol is a function - ELF64_ST_TYPE(symbol->st_info) == STT_FUNC && - // defined with global linkage - ELF64_ST_BIND(symbol->st_info) == STB_GLOBAL && - // and visible to external modules. - ELF64_ST_VISIBILITY(symbol->st_other) == STV_DEFAULT) { - const char* symbol_name = &string_table[symbol->st_name]; - callback(symbol_name); - } - } -} - void DSOLibrary::Unload() { dlclose(lib_handle_); lib_handle_ = nullptr; diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index 4bf9431fbfca..7b39bcd8da02 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -48,14 +48,6 @@ class LibraryModuleNode final : public ModuleNode { }; PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final { - if (name == "__get_func_names") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - Array names; - lib_->ListSymbols([&names](const char* symbol) { names.push_back(symbol); }); - *rv = std::move(names); - }); - } - TVMBackendPackedCFunc faddr; if (name == runtime::symbol::tvm_module_main) { const char* entry_name = diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h index 8648cd2ba1c8..d4d32abe2110 100644 --- a/src/runtime/library_module.h +++ b/src/runtime/library_module.h @@ -64,13 +64,6 @@ class Library : public Object { virtual void* GetSymbol(const char* name) = 0; // NOTE: we do not explicitly create an type index and type_key here for libary. // This is because we do not need dynamic type downcasting. - - /*! \brief List symbols available within the module - * - * \param callback The callback to be executed for each symbol in - * the library. - */ - virtual void ListSymbols(std::function callback) {} }; /*! diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 56705bd475e6..baa68feedfa2 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -171,7 +171,7 @@ PackedFunc LLVMModuleNode::GetFunction(const String& name, const ObjectPtrfunction_names_; }); } else if (name == "get_symbol") { diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 009fc1672ace..dcb870fad30e 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -451,7 +451,28 @@ void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT PrintConst(op, os, this); } void CodeGenC::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*) - os << "\"" << op->value << "\""; + const auto& str = op->value; + os << '"'; + for (size_t i = 0; i < str.size(); i++) { + char c = str.c_str()[i]; + switch (c) { + case '\n': + os << "\\n"; + break; + + case '\\': + case '"': + case '?': + os << '\\' << c; + break; + + default: + os << c; + break; + } + } + + os << '"'; } template diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index b22d32d6c5e3..2846108a87fb 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -35,6 +35,7 @@ #include #include "../../support/str_escape.h" +#include "../../support/utils.h" #include "../build_common.h" #include "../func_registry_generator.h" #include "codegen_params.h" @@ -51,8 +52,8 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_d declared_globals_.clear(); decl_stream << "// tvm target: " << target_str << "\n"; decl_stream << "#define TVM_EXPORTS\n"; - decl_stream << "#include \"tvm/runtime/c_runtime_api.h\"\n"; - decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n"; + DeclareIncludeTVMRuntimeAPI(); + DeclareIncludeTVMBackendAPI(); decl_stream << "#include \n"; decl_stream << "#include \n"; if (devices.find("ethos-u") != devices.end()) { @@ -69,6 +70,66 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_d CodeGenC::Init(output_ssa); } +void CodeGenCHost::DeclareIncludeTVMRuntimeAPI() { + decl_stream << "#include \"tvm/runtime/c_runtime_api.h\"\n"; + included_function_names_.insert("TVMAPISetLastError"); + included_function_names_.insert("TVMAPISetLastPythonError"); + included_function_names_.insert("TVMGetLastPythonError"); + included_function_names_.insert("TVMGetLastError"); + included_function_names_.insert("TVMGetLastBacktrace"); + included_function_names_.insert("TVMDropLastPythonError"); + included_function_names_.insert("TVMThrowLastError"); + included_function_names_.insert("TVMModLoadFromFile"); + included_function_names_.insert("TVMModImport"); + included_function_names_.insert("TVMModGetFunction"); + included_function_names_.insert("TVMModFree"); + included_function_names_.insert("TVMFuncFree"); + included_function_names_.insert("TVMFuncCall"); + included_function_names_.insert("TVMCFuncSetReturn"); + included_function_names_.insert("TVMCbArgToReturn"); + included_function_names_.insert("TVMFuncCreateFromCFunc"); + included_function_names_.insert("TVMFuncRegisterGlobal"); + included_function_names_.insert("TVMFuncGetGlobal"); + included_function_names_.insert("TVMFuncListGlobalNames"); + included_function_names_.insert("TVMFuncRemoveGlobal"); + included_function_names_.insert("TVMArrayAlloc"); + included_function_names_.insert("TVMArrayFree"); + included_function_names_.insert("TVMArrayCopyFromBytes"); + included_function_names_.insert("TVMArrayCopyToBytes"); + included_function_names_.insert("TVMArrayCopyFromTo"); + included_function_names_.insert("TVMArrayFromDLPack"); + included_function_names_.insert("TVMArrayToDLPack"); + included_function_names_.insert("TVMDLManagedTensorCallDeleter"); + included_function_names_.insert("TVMStreamCreate"); + included_function_names_.insert("TVMStreamFree"); + included_function_names_.insert("TVMSetStream"); + included_function_names_.insert("TVMSynchronize"); + included_function_names_.insert("TVMStreamStreamSynchronize"); + included_function_names_.insert("TVMObjectGetTypeIndex"); + included_function_names_.insert("TVMObjectTypeKey2Index"); + included_function_names_.insert("TVMObjectTypeIndex2Key"); + included_function_names_.insert("TVMObjectRetain"); + included_function_names_.insert("TVMObjectFree"); + included_function_names_.insert("TVMByteArrayFree"); + included_function_names_.insert("TVMDeviceAllocDataSpace"); + included_function_names_.insert("TVMDeviceAllocDataSpaceWithScope"); + included_function_names_.insert("TVMDeviceFreeDataSpace"); + included_function_names_.insert("TVMDeviceCopyDataFromTo"); + included_function_names_.insert("TVMObjectDerivedFrom"); +} + +void CodeGenCHost::DeclareIncludeTVMBackendAPI() { + decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n"; + included_function_names_.insert("TVMBackendGetFuncFromEnv"); + included_function_names_.insert("TVMBackendRegisterSystemLibSymbol"); + included_function_names_.insert("TVMBackendAllocWorkspace"); + included_function_names_.insert("TVMBackendFreeWorkspace"); + included_function_names_.insert("TVMBackendRegisterEnvCAPI"); + included_function_names_.insert("TVMBackendParallelLaunch"); + included_function_names_.insert("TVMBackendParallelBarrier"); + included_function_names_.insert("TVMBackendRunOnce"); +} + void CodeGenCHost::InitGlobalContext() { decl_stream << "void* " << tvm::runtime::symbol::tvm_module_ctx << " = NULL;\n"; } @@ -118,6 +179,10 @@ void CodeGenCHost::GenerateForwardFunctionDeclarations(String global_symbol, return; } } + if (tvm::support::StartsWith(global_symbol, "TVMBackend")) { + return; + } + this->PrintFuncPrefix(fwd_decl_stream); this->PrintType(ret_type, fwd_decl_stream); fwd_decl_stream << " " << global_symbol << "("; diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 3e013492efc2..c82f79589818 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -73,6 +73,21 @@ class CodeGenCHost : public CodeGenC { const Type& ret_type) override; Array GetFunctionNames() { return function_names_; } + protected: + /* \brief Names declared in external headers + * + * When encountering a `builtin::call_extern`, a forward declaration + * will usually be generated based on the arguments used in TIR. In + * some cases, this can conflict with the declaration used in the + * header file. For example, the `c_backend_api.h` header declares + * `void* TVMBackendStringRetValue(const char*)`, but the + * auto-generated declaration would have `uint8*` argument. + * + * Names in this set will be excluded from the automatic forward + * declaration, to avoid conflicting declarations. + */ + std::unordered_set included_function_names_; + private: /* \brief Internal structure to store information about function calls */ struct FunctionInfo { @@ -110,6 +125,9 @@ class CodeGenCHost : public CodeGenC { template inline void PrintTernaryCondExpr(const T* op, const char* compare, std::ostream& os); // NOLINT(*) + + void DeclareIncludeTVMRuntimeAPI(); + void DeclareIncludeTVMBackendAPI(); }; } // namespace codegen diff --git a/src/tir/transforms/generate_function_signature_metadata.cc b/src/tir/transforms/generate_function_signature_metadata.cc new file mode 100644 index 000000000000..637792afb901 --- /dev/null +++ b/src/tir/transforms/generate_function_signature_metadata.cc @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file generate_function_signature_metadata.cc + * \brief Split device function from host. + */ + +#define PICOJSON_USE_INT64 +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS +#endif +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { + +namespace { +picojson::value GenerateMetadata(const PrimFunc& func) { + std::vector params; + + return picojson::value(picojson::object({ + {"params", picojson::value(params)}, + })); +} + +picojson::value GenerateMetadata(const IRModule& mod) { + picojson::object functions; + for (const auto& [gvar, base_func] : mod->functions) { + bool is_externally_exposed = base_func->GetAttr(tvm::attr::kGlobalSymbol).defined(); + if (auto func = base_func.as(); func && is_externally_exposed) { + functions[gvar->name_hint] = GenerateMetadata(func.value()); + } + } + + return picojson::value(picojson::object({ + {"functions", picojson::value(functions)}, + })); +} + +std::string GenerateMetadataString(const IRModule& mod) { + return GenerateMetadata(mod).serialize(/* prettify = */ true); +} +} // namespace + +namespace transform { + +Pass GenerateFunctionSignatureMetadata() { + auto pass_func = [](IRModule mod, PassContext ctx) -> IRModule { + if (mod->ContainGlobalVar(runtime::symbol::tvm_get_tir_function_metadata)) { + return mod; + } + + std::string metadata = GenerateMetadataString(mod); + + Map func_attrs{ + {tvm::attr::kGlobalSymbol, String(runtime::symbol::tvm_get_tir_function_metadata)}, + {tvm::tir::attr::kIsHostFunc, Bool(true)}, + }; + Type ret_type = PrimType(DataType::Handle()); + PrimFunc metadata_func({}, Evaluate(ret(StringImm(metadata))), ret_type, {}, + DictAttrs(func_attrs)); + GlobalVar gvar(runtime::symbol::tvm_get_tir_function_metadata, FuncType({}, ret_type, {}, {})); + + mod.CopyOnWrite()->Add(gvar, metadata_func); + return mod; + }; + + return tvm::transform::CreateModulePass(pass_func, 0, "tir.GenerateFunctionSignatureMetadata", + {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.GenerateFunctionSignatureMetadata") + .set_body_typed(GenerateFunctionSignatureMetadata); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/tests/python/tir-transform/test_tir_transform_generate_function_signature_metadata.py b/tests/python/tir-transform/test_tir_transform_generate_function_signature_metadata.py new file mode 100644 index 000000000000..cfe8968e8631 --- /dev/null +++ b/tests/python/tir-transform/test_tir_transform_generate_function_signature_metadata.py @@ -0,0 +1,199 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json + +import pytest + +import tvm.testing +from tvm.script import ir as I, tir as T + + +class Base: + def test_metadata(self): + """Validate the generated metadata string + + The metadata string should be valid JSON, which is parsed to + the same structure as defined by a test cases `expected`. + Comparison is done after parsing, so that the test case is + agnostic to any pretty-printing done for the JSON string. + """ + + mod = self.mod + mod = tvm.tir.transform.GenerateFunctionSignatureMetadata()(mod) + + func = mod[tvm.runtime.Module._GET_TIR_FUNCTION_METADATA] + metadata_str = func.body.value.args[0].value + + metadata = json.loads(metadata_str) + assert metadata == self.expected + + def test_metadata_function(self): + """Validate the PrimFunc containing the metadata string""" + + mod = self.mod + mod = tvm.tir.transform.GenerateFunctionSignatureMetadata()(mod) + func = mod[tvm.runtime.Module._GET_TIR_FUNCTION_METADATA] + metadata_str = func.body.value.args[0].value + + @T.prim_func + def expected() -> T.handle: + T.func_attr( + { + "global_symbol": tvm.runtime.Module._GET_TIR_FUNCTION_METADATA, + "tir.is_host_func": True, + } + ) + return metadata_str + + tvm.ir.assert_structural_equal(func, expected) + + def test_no_other_changes_to_module(self): + """Only change to IRModule is the new function + + All other functions are pass through unmodified. + """ + mod = self.mod + mod = tvm.tir.transform.GenerateFunctionSignatureMetadata()(mod) + del mod[tvm.runtime.Module._GET_TIR_FUNCTION_METADATA] + + tvm.ir.assert_structural_equal(self.mod, mod) + + +class TestEmptyModule(Base): + @property + def mod(self): + @I.ir_module + class Module: + pass + + return Module + + expected = { + "functions": {}, + } + + +class TestSingleFunction(Base): + @property + def mod(self): + @I.ir_module + class Module: + @T.prim_func + def main(): + pass + + return Module + + expected = { + "functions": { + "main": { + "params": [], + }, + }, + } + + +class TestMultipleFunctions(Base): + @property + def mod(self): + @I.ir_module + class Module: + @T.prim_func + def func_a(): + pass + + @T.prim_func + def func_b(): + pass + + return Module + + expected = { + "functions": { + "func_a": { + "params": [], + }, + "func_b": { + "params": [], + }, + }, + } + + +class TestPrivateFunction(Base): + """Private functions should not be exposed externally""" + + @property + def mod(self): + @I.ir_module + class Module: + @T.prim_func + def func_a(): + pass + + @T.prim_func(private=True) + def func_b(): + pass + + return Module + + expected = { + "functions": { + "func_a": { + "params": [], + }, + }, + } + + +@pytest.mark.xfail(reason="Not yet implemented") +class TestPrimitiveArguments(Base): + """Annotation of primitive arguments""" + + @property + def mod(self): + @I.ir_module + class Module: + @T.prim_func + def func_a(A: T.int32, B: T.float16): + pass + + return Module + + expected = { + "functions": { + "func_a": { + "params": [ + { + "name": "A", + "type": "PrimType", + "dtype": "int32", + }, + { + "name": "A", + "type": "PrimType", + "dtype": "float16", + }, + ], + }, + }, + } + + +if __name__ == "__main__": + tvm.testing.main()