diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 4159c4b2e764..0d7f9f7e2b51 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -1145,6 +1145,31 @@ struct PackedFuncValueConverter { } \ } +#define TVM_MODULE_VTABLE_BEGIN(TypeKey) \ + const char* type_key() const final { return TypeKey; } \ + PackedFunc GetFunction(const String& _name, const ObjectPtr& _self) final { \ + using SelfPtr = std::remove_cv_t; +#define TVM_MODULE_VTABLE_END() \ + return PackedFunc(nullptr); \ + } +#define TVM_MODULE_VTABLE_ENTRY(Name, MemFunc) \ + if (_name == Name) { \ + return PackedFunc([_self](TVMArgs args, TVMRetValue* rv) -> void { \ + using Helper = ::tvm::runtime::detail::ModuleVTableEntryHelper; \ + SelfPtr self = static_cast(_self.get()); \ + CHECK_EQ(args.size(), Helper::LenArgs) \ + << "Function `" << self->type_key() << "::" << Name << "` requires " << Helper::LenArgs \ + << " arguments, but got " << args.size(); \ + Helper::Call(rv, self, MemFunc, args, Helper::IndexSeq{}); \ + }); \ + } +#define TVM_MODULE_VTABLE_ENTRY_PACKED(Name, Func) \ + if (_name == Name) { \ + auto f = (Func); \ + using FType = ::tvm::runtime::detail::function_signature::FType; \ + return TypedPackedFunc(std::move(f)).packed(); \ + } + /*! * \brief Export typed function as a PackedFunc * that can be loaded by LibraryModule. @@ -1330,6 +1355,61 @@ inline void for_each(const F& f, Args&&... args) { // NOLINT(*) for_each_dispatcher::run(f, std::forward(args)...); } +template +struct ModuleVTableEntryHelper {}; + +template +struct ModuleVTableEntryHelper { + using MemFnType = R (T::*)(Args...) const; + using IndexSeq = std::index_sequence_for; + static constexpr const std::size_t LenArgs = sizeof...(Args); + + template + static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args, + std::index_sequence) { + *rv = (self->*f)(args[Is]...); + } +}; + +template +struct ModuleVTableEntryHelper { + using MemFnType = R (T::*)(Args...); + using IndexSeq = std::index_sequence_for; + static constexpr const std::size_t LenArgs = sizeof...(Args); + + template + static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args, + std::index_sequence) { + *rv = (self->*f)(args[Is]...); + } +}; + +template +struct ModuleVTableEntryHelper { + using MemFnType = void (T::*)(Args...) const; + using IndexSeq = std::index_sequence_for; + static constexpr const std::size_t LenArgs = sizeof...(Args); + + template + static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args, + std::index_sequence) { + (self->*f)(args[Is]...); + } +}; + +template +struct ModuleVTableEntryHelper { + using MemFnType = void (T::*)(Args...); + using IndexSeq = std::index_sequence_for; + static constexpr const std::size_t LenArgs = sizeof...(Args); + + template + static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args, + std::index_sequence) { + (self->*f)(args[Is]...); + } +}; + namespace parameter_pack { template diff --git a/include/tvm/runtime/vm/executable.h b/include/tvm/runtime/vm/executable.h index d4872837b0c6..12bb115aa783 100644 --- a/include/tvm/runtime/vm/executable.h +++ b/include/tvm/runtime/vm/executable.h @@ -57,19 +57,28 @@ struct VMFunction; */ class TVM_DLL Executable : public ModuleNode { public: - /*! - * \brief Get a PackedFunc from an executable module. - * - * \param name the name of the function. - * \param sptr_to_self The shared_ptr that points to this module node. - * - * \return PackedFunc or nullptr when it is not available. - */ - PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; + TVM_MODULE_VTABLE_BEGIN("VMExecutable"); + TVM_MODULE_VTABLE_ENTRY("get_lib", &Executable::GetLib); + TVM_MODULE_VTABLE_ENTRY("get_bytecode", &Executable::GetBytecode); + TVM_MODULE_VTABLE_ENTRY("get_constants", &Executable::GetConstants); + TVM_MODULE_VTABLE_ENTRY("get_virtual_devices", &Executable::GetVirtualDevices); + TVM_MODULE_VTABLE_ENTRY("get_primitives", &Executable::GetPrimitives); + TVM_MODULE_VTABLE_ENTRY("get_stats", &Executable::Stats); + TVM_MODULE_VTABLE_ENTRY("save", &Executable::Save); + TVM_MODULE_VTABLE_ENTRY("get_function_arity", &Executable::GetFunctionArity); + TVM_MODULE_VTABLE_ENTRY("get_function_param_name", &Executable::GetFunctionParameterName); + TVM_MODULE_VTABLE_ENTRY("vm_load_executable", &Executable::VMLoadExecutable); + TVM_MODULE_VTABLE_ENTRY("move_late_bound_consts", &Executable::MoveLateBoundConstantsToFile); + TVM_MODULE_VTABLE_ENTRY("get_late_bound_consts", &Executable::GetLateBoundConstants); + TVM_MODULE_VTABLE_ENTRY("load_late_bound_consts", &Executable::LoadLateBoundConstantsFromFile); + TVM_MODULE_VTABLE_ENTRY("load_late_bound_consts_from_map", + &Executable::LoadLateBoundConstantsFromMap); + TVM_MODULE_VTABLE_END(); /*! \brief Get the property of the runtime module .*/ int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; }; - + /*! \brief Creates a VM that loads `this` as the executable. */ + Module VMLoadExecutable(); /*! * \brief Write the Executable to the binary stream in serialized form. * @@ -123,17 +132,17 @@ class TVM_DLL Executable : public ModuleNode { * Must be called before \p SaveToBinary and friends if late-bound constants are * desired. Otherwise can be ignore. */ - void MoveLateBoundConstantsToStream(dmlc::Stream* stream, size_t byte_limit); + void MoveLateBoundConstantsToStream(dmlc::Stream* stream, int64_t byte_limit); /*! * \brief As for \p MoveLateBoundConstantsToStream, but save to file at \p path. */ - void MoveLateBoundConstantsToFile(const std::string& path, size_t byte_limit); + void MoveLateBoundConstantsToFile(const std::string& path, int64_t byte_limit); /*! * \brief Get a map of all constants with larger that byte_limit in size. */ - Map GetLateBoundConstants(size_t byte_limit); + Map GetLateBoundConstants(int64_t byte_limit); /*! * \brief Restores the late-bound constants for the executable (if any) from given byte-stream. @@ -255,12 +264,10 @@ class TVM_DLL Executable : public ModuleNode { * \param index Parameter index. * \return The parameter name. */ - std::string GetFunctionParameterName(std::string func, uint32_t index) const; + std::string GetFunctionParameterName(std::string func, int index) const; virtual ~Executable() {} - const char* type_key() const final { return "VMExecutable"; } - /*! * \brief The (compile-time, virtual) devices corresponding to each device index. * This vector contains a pair Device and its memory_scope. diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 58c509f8d967..161c6dbfbd76 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -55,83 +55,24 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr); // Helper to deserialize a serialized vm instruction. Instruction DeserializeInstruction(const VMInstructionSerializer& instr); -PackedFunc Executable::GetFunction(const String& name, const ObjectPtr& sptr_to_self) { - if (name == "get_lib") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLib(); }); - } else if (name == "get_bytecode") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetBytecode(); }); - } else if (name == "get_constants") { - return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetConstants(); }); - } else if (name == "get_virtual_devices") { - return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetVirtualDevices(); }); - } else if (name == "get_primitives") { - return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetPrimitives(); }); - } else if (name == "get_stats") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Stats(); }); - } else if (name == "save") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Save(); }); - } else if (name == "get_function_arity") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - std::string func_name = args[0]; - *rv = this->GetFunctionArity(func_name); - }); - } else if (name == "get_function_param_name") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - std::string func_name = args[0]; - int index = args[1]; - *rv = this->GetFunctionParameterName(func_name, index); - }); - } else if (name == "vm_load_executable") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - auto vm = make_object(); - ICHECK(sptr_to_self.get() == this); - vm->LoadExecutable(GetObjectPtr(this)); - *rv = Module(vm); - }); - } else if (name == "move_late_bound_consts") { - return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.size(), 2); - std::string path = args[0]; - uint64_t byte_limit = args[1]; - MoveLateBoundConstantsToFile(path, static_cast(byte_limit)); - }); - } else if (name == "get_late_bound_consts") { - return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.size(), 1); - uint64_t byte_limit = args[0]; - Map consts = GetLateBoundConstants(static_cast(byte_limit)); - *rv = consts; - }); - } else if (name == "load_late_bound_consts") { - return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.size(), 1); - std::string path = args[0]; - LoadLateBoundConstantsFromFile(path); - }); - } else if (name == "load_late_bound_consts_from_map") { - return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.size(), 1); - Map map = args[0]; - LoadLateBoundConstantsFromMap(map); - }); - } - return nullptr; -} - const VMFunction& Executable::GetVMFunctionWithName(const std::string& func_name) const { auto it = global_map.find(func_name); ICHECK(it != global_map.end()) << "Cannot find function " << func_name << " in executable"; return functions[it->second]; } +Module Executable::VMLoadExecutable() { + auto vm = make_object(); + vm->LoadExecutable(GetObjectPtr(this)); + return Module(vm); +} + int Executable::GetFunctionArity(std::string func_name) const { const auto& func = GetVMFunctionWithName(func_name); return func.params.size(); } -std::string Executable::GetFunctionParameterName(std::string func_name, uint32_t index) const { +std::string Executable::GetFunctionParameterName(std::string func_name, int index) const { const auto& func = GetVMFunctionWithName(func_name); ICHECK_LT(index, func.params.size()) << "Invalid parameter index"; return func.params[index]; @@ -311,7 +252,7 @@ void Executable::SaveVirtualDevicesSection(dmlc::Stream* strm) { strm->Write(host_device_index); } -Map Executable::GetLateBoundConstants(size_t byte_limit) { +Map Executable::GetLateBoundConstants(int64_t byte_limit) { ICHECK(late_bound_constant_names.empty()); late_bound_constant_names.reserve(constants.size()); Map map; @@ -319,7 +260,7 @@ Map Executable::GetLateBoundConstants(size_t byte_limit) { for (size_t const_index = 0; const_index < constants.size(); ++const_index) { const auto ndarray = Downcast(constants[const_index]); ICHECK(ndarray.defined()) << "Undefined constant at index " << const_index; - size_t num_bytes = runtime::GetDataSize(*ndarray.operator->()); + int64_t num_bytes = runtime::GetDataSize(*ndarray.operator->()); if (num_bytes < byte_limit) { // Leave as immediate. late_bound_constant_names.emplace_back(nullptr); @@ -337,12 +278,12 @@ Map Executable::GetLateBoundConstants(size_t byte_limit) { return map; } -void Executable::MoveLateBoundConstantsToStream(dmlc::Stream* stream, size_t byte_limit) { +void Executable::MoveLateBoundConstantsToStream(dmlc::Stream* stream, int64_t byte_limit) { Map map = GetLateBoundConstants(byte_limit); runtime::SaveParams(stream, map); } -void Executable::MoveLateBoundConstantsToFile(const std::string& path, size_t byte_limit) { +void Executable::MoveLateBoundConstantsToFile(const std::string& path, int64_t byte_limit) { tvm::runtime::SimpleBinaryFileStream stream(path, "wb"); MoveLateBoundConstantsToStream(&stream, byte_limit); }