diff --git a/include/tvm/runtime/c_backend_api.h b/include/tvm/runtime/c_backend_api.h index 8fde5948f993..82d420bd8b47 100644 --- a/include/tvm/runtime/c_backend_api.h +++ b/include/tvm/runtime/c_backend_api.h @@ -166,6 +166,18 @@ TVM_DLL int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv); */ TVM_DLL int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes); +/*! \brief Generate a pointer usable as kTVMStr in a TVMRetValue + * + * While TVMArgValue uses `const char*` to represent string arguments, + * TVMRetValue represents string return values as a heap-allocated C++ + * `std::string` objects. + * + * \param c_str A null-terminated C-style string + * + * \return A pointer to the heap-allocated C++ `std::string`. + */ +void* TVMBackendStringRetValue(const char* c_str); + #ifdef __cplusplus } // TVM_EXTERN_C #endif diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 308c936d624a..a3b85d517e1f 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -298,6 +298,8 @@ constexpr const char* tvm_module_main = "__tvm_main__"; constexpr const char* tvm_param_prefix = "__tvm_param__"; /*! \brief A PackedFunc that looks up linked parameters by storage_id. */ constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param"; +/*! \brief Placeholder for the module's entry function. */ +constexpr const char* tvm_get_tir_function_metadata = "__tvm_get_tir_function_metadata__"; /*! \brief Model entrypoint generated as an interface to the AOT function outside of TIR */ constexpr const char* tvm_entrypoint_suffix = "run"; } // namespace symbol diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 7266f8c4a50a..e27e9d59fd98 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -974,7 +974,8 @@ class TVMRetValue : public TVMPODValue_ { */ static TVMRetValue MoveFromCHost(TVMValue value, int type_code) { // Can move POD and everything under the object system. - ICHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle); + ICHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle || + type_code == kTVMStr); TVMRetValue ret; ret.value_ = value; ret.type_code_ = type_code; diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 98edbeaceb26..2ef6f7c8c9d2 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -263,6 +263,13 @@ TVM_DLL Pass LowerCustomDatatypes(); */ TVM_DLL Pass DecorateDeviceScope(); +/*! + * \brief Generate module metadata describing function signatures + * + * \return The pass. + */ +TVM_DLL Pass GenerateFunctionSignatureMetadata(); + /*! * \brief Annotate locations that should be run on the device * diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 2c3eff700009..b432022e8f45 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -17,10 +17,15 @@ # pylint: disable=invalid-name, unused-import, import-outside-toplevel, inconsistent-return-statements """Runtime Module namespace.""" -import os + import ctypes +import difflib +import json +import os import struct -from typing import Sequence + +from typing import Sequence, Iterable, Iterator, Optional + import numpy as np from tvm._ffi.base import _LIB, check_call, c_str, string_types, _RUNTIME_ONLY @@ -100,7 +105,20 @@ class ModulePropertyMask(object): class Module(object): """Runtime Module.""" - __slots__ = ["handle", "_entry", "entry_name"] + __slots__ = ["handle", "_entry", "entry_name", "_cached_metadata"] + + _GET_TIR_FUNCTION_METADATA = "__tvm_get_tir_function_metadata__" + """The function to retrieve TIR metadata, if any + + This is for internal use only. If present in the module, this + should be a function that returns a JSON string describing the + functions provided by the module, along with the signatures of + those functions. + + This parameter corresponds to the C++ value + `tvm::runtime::symbol::tvm_get_tir_function_metadata`, located in + `#include `. + """ def __init__(self, handle): self.handle = handle @@ -163,18 +181,54 @@ def get_function(self, name, query_imports=False): Returns ------- - f : tvm.runtime.PackedFunc + func : tvm.runtime.PackedFunc The result function. """ + func = self._get_function(name, query_imports=query_imports) + if func is None: + nearby_names = difflib.get_close_matches(name, self.keys()) + if nearby_names: + message = ( + f"Module has no function '{name}'. " + f"The module does contain functions with similar names: " + f"{nearby_names}." + ) + else: + message = ( + f"Module has no function '{name}'. " + f"The module does not contain any function with a similar name." + ) + raise KeyError(message) + + return func + + def _get_function(self, name, query_imports=False) -> Optional[PackedFunc]: + """Get function from the module. + + Parameters + ---------- + name : str + The name of the function + + query_imports : bool + Whether also query modules imported by this module. + + Returns + ------- + func : Optional[tvm.runtime.PackedFunc] + The result function, or None if it cannot be found + """ ret_handle = PackedFuncHandle() check_call( _LIB.TVMModGetFunction( self.handle, c_str(name), ctypes.c_int(query_imports), ctypes.byref(ret_handle) ) ) - if not ret_handle.value: - raise AttributeError(f"Module has no function '{name}'") - return PackedFunc(ret_handle, False) + # pylint: disable=using-constant-test + if ret_handle.value: + return PackedFunc(ret_handle, False) + else: + return None def import_module(self, module): """Add module to the import list of current one. @@ -186,11 +240,59 @@ def import_module(self, module): """ check_call(_LIB.TVMModImport(self.handle, module.handle)) - def __getitem__(self, name): + def __getitem__(self, name: str) -> PackedFunc: + """Return the PackedFunc associated with the gven name + + Parameters + ---------- + name: str + The name of the function to be returned + + Returns + ------- + PackedFunc + + """ + if not isinstance(name, string_types): - raise ValueError("Can only take string as function name") + raise TypeError(f"Module.__getitem__ expects a string, but received {type(name)}") return self.get_function(name) + def __contains__(self, key: str) -> bool: + return key in self.keys() + + @property + def _metadata(self): + if hasattr(self, "_cached_metadata"): + # pylint: disable=access-member-before-definition + return self._cached_metadata + + metadata_func = self._get_function(Module._GET_TIR_FUNCTION_METADATA) + if metadata_func is None: + raise RuntimeError("Cannot find function metadata in runtime.Module") + + self._cached_metadata = json.loads(metadata_func()) + return self._cached_metadata + + def keys(self) -> Sequence[str]: + """Return a list of functions in the module + + Returns + ------- + Sequence[str] + The functions in the module + """ + for function in self._metadata["functions"]: + yield function + + def values(self) -> Iterator[PackedFunc]: + for key in self.keys(): + yield self[key] + + def items(self) -> Iterator[PackedFunc]: + for key in self.keys(): + yield key, self[key] + def __eq__(self, other): return self.handle.value == other.handle.value diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 8816880e7b52..bce016d235dc 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -1639,6 +1639,9 @@ def ret(val): """ val = convert(val) + if isinstance(val, tvm.runtime.String): + val = StringImm(val) + return call_intrin(val.dtype, "tir.ret", val) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index c2022b918643..6ffd6974f30c 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -488,6 +488,22 @@ def MakeUnpackedAPI(): return _ffi_api.MakeUnpackedAPI() # type: ignore +def GenerateFunctionSignatureMetadata(): + """Generate metadata describing the function signatures + + Generate a metadata function that returns a JSON-formatted string, + describing the functions available within the IRModule and their + signatures. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + + """ + return _ffi_api.GenerateFunctionSignatureMetadata() # type: ignore + + def AnnotateDeviceRegions(): """Annotate locations that should be run on the device diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 7ea5032fa0cc..4fcd86d85417 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -569,8 +569,22 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) Array mixed_pass_list; - // FPComputeLegalize uses the target attrs added by BindTarget, so it must come first + // AnnotateEntryFunc inspects user-defined functions to provide a + // default function to call. Therefore, it should appear before any + // passes that would generate additional PrimFuncs. + mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc()); + + // GenerateFunctionSignatureMetadata produces a function with + // `tir.is_host_func`, relying on `BindTarget` to generate the + // correct target annotation. Therefore, it must come before + // BindTarget. + mixed_pass_list.push_back(tir::transform::GenerateFunctionSignatureMetadata()); + + // Many later passes, such as FP8ComputeLegalize, use the target + // attrs added by BindTarget. Therefore, BindTarget should occur as + // early as possible. mixed_pass_list.push_back(tir::transform::BindTarget(target)); + mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize()); // VerifyVTCMLimit must occur before LowerVtcmAlloc @@ -580,8 +594,6 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::VerifyMemory()); - mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc()); - bool detect_global_barrier = pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value(); if (detect_global_barrier) { diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index ea22b89dd771..11c178b110f6 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -555,6 +555,12 @@ int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) { return 0; } +void* TVMBackendStringRetValue(const char* c_str) { + // The `std::string` allocated here is immediately stored in a + // `TVMRetValue`, which takes ownership of the object. + return new std::string(c_str); +} + int TVMFuncFree(TVMFunctionHandle func) { return TVMObjectFree(func); } int TVMByteArrayFree(TVMByteArray* arr) { diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 009fc1672ace..dcb870fad30e 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -451,7 +451,28 @@ void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT PrintConst(op, os, this); } void CodeGenC::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*) - os << "\"" << op->value << "\""; + const auto& str = op->value; + os << '"'; + for (size_t i = 0; i < str.size(); i++) { + char c = str.c_str()[i]; + switch (c) { + case '\n': + os << "\\n"; + break; + + case '\\': + case '"': + case '?': + os << '\\' << c; + break; + + default: + os << c; + break; + } + } + + os << '"'; } template diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index b22d32d6c5e3..2846108a87fb 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -35,6 +35,7 @@ #include #include "../../support/str_escape.h" +#include "../../support/utils.h" #include "../build_common.h" #include "../func_registry_generator.h" #include "codegen_params.h" @@ -51,8 +52,8 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_d declared_globals_.clear(); decl_stream << "// tvm target: " << target_str << "\n"; decl_stream << "#define TVM_EXPORTS\n"; - decl_stream << "#include \"tvm/runtime/c_runtime_api.h\"\n"; - decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n"; + DeclareIncludeTVMRuntimeAPI(); + DeclareIncludeTVMBackendAPI(); decl_stream << "#include \n"; decl_stream << "#include \n"; if (devices.find("ethos-u") != devices.end()) { @@ -69,6 +70,66 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_d CodeGenC::Init(output_ssa); } +void CodeGenCHost::DeclareIncludeTVMRuntimeAPI() { + decl_stream << "#include \"tvm/runtime/c_runtime_api.h\"\n"; + included_function_names_.insert("TVMAPISetLastError"); + included_function_names_.insert("TVMAPISetLastPythonError"); + included_function_names_.insert("TVMGetLastPythonError"); + included_function_names_.insert("TVMGetLastError"); + included_function_names_.insert("TVMGetLastBacktrace"); + included_function_names_.insert("TVMDropLastPythonError"); + included_function_names_.insert("TVMThrowLastError"); + included_function_names_.insert("TVMModLoadFromFile"); + included_function_names_.insert("TVMModImport"); + included_function_names_.insert("TVMModGetFunction"); + included_function_names_.insert("TVMModFree"); + included_function_names_.insert("TVMFuncFree"); + included_function_names_.insert("TVMFuncCall"); + included_function_names_.insert("TVMCFuncSetReturn"); + included_function_names_.insert("TVMCbArgToReturn"); + included_function_names_.insert("TVMFuncCreateFromCFunc"); + included_function_names_.insert("TVMFuncRegisterGlobal"); + included_function_names_.insert("TVMFuncGetGlobal"); + included_function_names_.insert("TVMFuncListGlobalNames"); + included_function_names_.insert("TVMFuncRemoveGlobal"); + included_function_names_.insert("TVMArrayAlloc"); + included_function_names_.insert("TVMArrayFree"); + included_function_names_.insert("TVMArrayCopyFromBytes"); + included_function_names_.insert("TVMArrayCopyToBytes"); + included_function_names_.insert("TVMArrayCopyFromTo"); + included_function_names_.insert("TVMArrayFromDLPack"); + included_function_names_.insert("TVMArrayToDLPack"); + included_function_names_.insert("TVMDLManagedTensorCallDeleter"); + included_function_names_.insert("TVMStreamCreate"); + included_function_names_.insert("TVMStreamFree"); + included_function_names_.insert("TVMSetStream"); + included_function_names_.insert("TVMSynchronize"); + included_function_names_.insert("TVMStreamStreamSynchronize"); + included_function_names_.insert("TVMObjectGetTypeIndex"); + included_function_names_.insert("TVMObjectTypeKey2Index"); + included_function_names_.insert("TVMObjectTypeIndex2Key"); + included_function_names_.insert("TVMObjectRetain"); + included_function_names_.insert("TVMObjectFree"); + included_function_names_.insert("TVMByteArrayFree"); + included_function_names_.insert("TVMDeviceAllocDataSpace"); + included_function_names_.insert("TVMDeviceAllocDataSpaceWithScope"); + included_function_names_.insert("TVMDeviceFreeDataSpace"); + included_function_names_.insert("TVMDeviceCopyDataFromTo"); + included_function_names_.insert("TVMObjectDerivedFrom"); +} + +void CodeGenCHost::DeclareIncludeTVMBackendAPI() { + decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n"; + included_function_names_.insert("TVMBackendGetFuncFromEnv"); + included_function_names_.insert("TVMBackendRegisterSystemLibSymbol"); + included_function_names_.insert("TVMBackendAllocWorkspace"); + included_function_names_.insert("TVMBackendFreeWorkspace"); + included_function_names_.insert("TVMBackendRegisterEnvCAPI"); + included_function_names_.insert("TVMBackendParallelLaunch"); + included_function_names_.insert("TVMBackendParallelBarrier"); + included_function_names_.insert("TVMBackendRunOnce"); +} + void CodeGenCHost::InitGlobalContext() { decl_stream << "void* " << tvm::runtime::symbol::tvm_module_ctx << " = NULL;\n"; } @@ -118,6 +179,10 @@ void CodeGenCHost::GenerateForwardFunctionDeclarations(String global_symbol, return; } } + if (tvm::support::StartsWith(global_symbol, "TVMBackend")) { + return; + } + this->PrintFuncPrefix(fwd_decl_stream); this->PrintType(ret_type, fwd_decl_stream); fwd_decl_stream << " " << global_symbol << "("; diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 3e013492efc2..c82f79589818 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -73,6 +73,21 @@ class CodeGenCHost : public CodeGenC { const Type& ret_type) override; Array GetFunctionNames() { return function_names_; } + protected: + /* \brief Names declared in external headers + * + * When encountering a `builtin::call_extern`, a forward declaration + * will usually be generated based on the arguments used in TIR. In + * some cases, this can conflict with the declaration used in the + * header file. For example, the `c_backend_api.h` header declares + * `void* TVMBackendStringRetValue(const char*)`, but the + * auto-generated declaration would have `uint8*` argument. + * + * Names in this set will be excluded from the automatic forward + * declaration, to avoid conflicting declarations. + */ + std::unordered_set included_function_names_; + private: /* \brief Internal structure to store information about function calls */ struct FunctionInfo { @@ -110,6 +125,9 @@ class CodeGenCHost : public CodeGenC { template inline void PrintTernaryCondExpr(const T* op, const char* compare, std::ostream& os); // NOLINT(*) + + void DeclareIncludeTVMRuntimeAPI(); + void DeclareIncludeTVMBackendAPI(); }; } // namespace codegen diff --git a/src/tir/transforms/generate_function_signature_metadata.cc b/src/tir/transforms/generate_function_signature_metadata.cc new file mode 100644 index 000000000000..637792afb901 --- /dev/null +++ b/src/tir/transforms/generate_function_signature_metadata.cc @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file generate_function_signature_metadata.cc + * \brief Split device function from host. + */ + +#define PICOJSON_USE_INT64 +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS +#endif +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { + +namespace { +picojson::value GenerateMetadata(const PrimFunc& func) { + std::vector params; + + return picojson::value(picojson::object({ + {"params", picojson::value(params)}, + })); +} + +picojson::value GenerateMetadata(const IRModule& mod) { + picojson::object functions; + for (const auto& [gvar, base_func] : mod->functions) { + bool is_externally_exposed = base_func->GetAttr(tvm::attr::kGlobalSymbol).defined(); + if (auto func = base_func.as(); func && is_externally_exposed) { + functions[gvar->name_hint] = GenerateMetadata(func.value()); + } + } + + return picojson::value(picojson::object({ + {"functions", picojson::value(functions)}, + })); +} + +std::string GenerateMetadataString(const IRModule& mod) { + return GenerateMetadata(mod).serialize(/* prettify = */ true); +} +} // namespace + +namespace transform { + +Pass GenerateFunctionSignatureMetadata() { + auto pass_func = [](IRModule mod, PassContext ctx) -> IRModule { + if (mod->ContainGlobalVar(runtime::symbol::tvm_get_tir_function_metadata)) { + return mod; + } + + std::string metadata = GenerateMetadataString(mod); + + Map func_attrs{ + {tvm::attr::kGlobalSymbol, String(runtime::symbol::tvm_get_tir_function_metadata)}, + {tvm::tir::attr::kIsHostFunc, Bool(true)}, + }; + Type ret_type = PrimType(DataType::Handle()); + PrimFunc metadata_func({}, Evaluate(ret(StringImm(metadata))), ret_type, {}, + DictAttrs(func_attrs)); + GlobalVar gvar(runtime::symbol::tvm_get_tir_function_metadata, FuncType({}, ret_type, {}, {})); + + mod.CopyOnWrite()->Add(gvar, metadata_func); + return mod; + }; + + return tvm::transform::CreateModulePass(pass_func, 0, "tir.GenerateFunctionSignatureMetadata", + {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.GenerateFunctionSignatureMetadata") + .set_body_typed(GenerateFunctionSignatureMetadata); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index bf1f3a9e7fd2..cc7c14fde5e1 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -90,6 +90,10 @@ class ReturnRewriter : public StmtMutator { } else if (dtype.is_void()) { info.tcode = kTVMNullptr; info.expr = val; + } else if (val->IsInstance()) { + info.tcode = kTVMStr; + info.expr = Call(DataType::Handle(), builtin::call_pure_extern(), + {StringImm("TVMBackendStringRetValue"), val}); } else { LOG(FATAL) << "data type " << dtype << " not supported yet"; } diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index f1316ae3cee0..a9694030e5d7 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -14,14 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + import collections import ctypes import json import math -import numpy as np -import pytest +import pathlib import re import sys +import tempfile + +import numpy as np +import pytest + import tvm import tvm.testing @@ -1109,5 +1114,104 @@ def func(): built = tvm.build(func, target="llvm") +save_and_reload = tvm.testing.parameter( + by_dict={ + "llvm_module": False, + "library_module": True, + } +) + + +def _save_and_reload(mod): + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir = pathlib.Path(temp_dir) + model_path = temp_dir.joinpath("lib.so") + mod.export_library(model_path) + mod = tvm.runtime.load_module(model_path) + return mod + + +@tvm.testing.requires_llvm +def test_return_string(save_and_reload: bool): + @I.ir_module + class mod: + @T.prim_func + def main(): + return "my string here" + + mod = tvm.build(mod) + + if save_and_reload: + mod = _save_and_reload(mod) + + res = mod() + assert res == "my string here" + + +@tvm.testing.requires_llvm +def test_inspect_module_contents(save_and_reload: bool): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(1, dtype="float32")): + mod.subroutine(A.data) + + @T.prim_func(private=True) + def subroutine(A_data: T.handle("float32")): + A = T.decl_buffer(1, dtype="float32", data=A_data) + A[0] = 42.0 + + @T.prim_func + def main_2(A: T.Buffer(1, dtype="float32")): + mod.subroutine(A.data) + + mod = tvm.build(mod, target="llvm") + + if save_and_reload: + mod = _save_and_reload(mod) + + assert "main" in mod + assert "main_2" in mod + assert "subroutine" not in mod + + assert set(mod.keys()) == set(["main", "main_2"]) + + +@tvm.testing.requires_llvm +def test_suggest_nearby_name(save_and_reload: bool): + @I.ir_module + class mod: + @T.prim_func + def main(): + pass + + @T.prim_func + def long_tedious_name_that_would_be_easy_to_misspell(): + pass + + mod = tvm.build(mod, target="llvm") + + if save_and_reload: + mod = _save_and_reload(mod) + + mod["main"] + mod["long_tedious_name_that_would_be_easy_to_misspell"] + + with pytest.raises(KeyError, match=r"similar names: \['main'\]"): + mod["mian"] + + with pytest.raises( + KeyError, + match=r"similar names: \['long_tedious_name_that_would_be_easy_to_misspell'\]", + ): + mod["long_tedious_name_that_would_be_easy_to_mispell"] + + with pytest.raises( + KeyError, + match=r"does not contain any function with a similar name", + ): + mod["unrelated_name_that_is_different_from_any_valid_name"] + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_transform_generate_function_signature_metadata.py b/tests/python/tir-transform/test_tir_transform_generate_function_signature_metadata.py new file mode 100644 index 000000000000..cfe8968e8631 --- /dev/null +++ b/tests/python/tir-transform/test_tir_transform_generate_function_signature_metadata.py @@ -0,0 +1,199 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json + +import pytest + +import tvm.testing +from tvm.script import ir as I, tir as T + + +class Base: + def test_metadata(self): + """Validate the generated metadata string + + The metadata string should be valid JSON, which is parsed to + the same structure as defined by a test cases `expected`. + Comparison is done after parsing, so that the test case is + agnostic to any pretty-printing done for the JSON string. + """ + + mod = self.mod + mod = tvm.tir.transform.GenerateFunctionSignatureMetadata()(mod) + + func = mod[tvm.runtime.Module._GET_TIR_FUNCTION_METADATA] + metadata_str = func.body.value.args[0].value + + metadata = json.loads(metadata_str) + assert metadata == self.expected + + def test_metadata_function(self): + """Validate the PrimFunc containing the metadata string""" + + mod = self.mod + mod = tvm.tir.transform.GenerateFunctionSignatureMetadata()(mod) + func = mod[tvm.runtime.Module._GET_TIR_FUNCTION_METADATA] + metadata_str = func.body.value.args[0].value + + @T.prim_func + def expected() -> T.handle: + T.func_attr( + { + "global_symbol": tvm.runtime.Module._GET_TIR_FUNCTION_METADATA, + "tir.is_host_func": True, + } + ) + return metadata_str + + tvm.ir.assert_structural_equal(func, expected) + + def test_no_other_changes_to_module(self): + """Only change to IRModule is the new function + + All other functions are pass through unmodified. + """ + mod = self.mod + mod = tvm.tir.transform.GenerateFunctionSignatureMetadata()(mod) + del mod[tvm.runtime.Module._GET_TIR_FUNCTION_METADATA] + + tvm.ir.assert_structural_equal(self.mod, mod) + + +class TestEmptyModule(Base): + @property + def mod(self): + @I.ir_module + class Module: + pass + + return Module + + expected = { + "functions": {}, + } + + +class TestSingleFunction(Base): + @property + def mod(self): + @I.ir_module + class Module: + @T.prim_func + def main(): + pass + + return Module + + expected = { + "functions": { + "main": { + "params": [], + }, + }, + } + + +class TestMultipleFunctions(Base): + @property + def mod(self): + @I.ir_module + class Module: + @T.prim_func + def func_a(): + pass + + @T.prim_func + def func_b(): + pass + + return Module + + expected = { + "functions": { + "func_a": { + "params": [], + }, + "func_b": { + "params": [], + }, + }, + } + + +class TestPrivateFunction(Base): + """Private functions should not be exposed externally""" + + @property + def mod(self): + @I.ir_module + class Module: + @T.prim_func + def func_a(): + pass + + @T.prim_func(private=True) + def func_b(): + pass + + return Module + + expected = { + "functions": { + "func_a": { + "params": [], + }, + }, + } + + +@pytest.mark.xfail(reason="Not yet implemented") +class TestPrimitiveArguments(Base): + """Annotation of primitive arguments""" + + @property + def mod(self): + @I.ir_module + class Module: + @T.prim_func + def func_a(A: T.int32, B: T.float16): + pass + + return Module + + expected = { + "functions": { + "func_a": { + "params": [ + { + "name": "A", + "type": "PrimType", + "dtype": "int32", + }, + { + "name": "A", + "type": "PrimType", + "dtype": "float16", + }, + ], + }, + }, + } + + +if __name__ == "__main__": + tvm.testing.main()