Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -1145,6 +1145,31 @@ struct PackedFuncValueConverter {
} \
}

#define TVM_MODULE_VTABLE_BEGIN(TypeKey) \
const char* type_key() const final { return TypeKey; } \
PackedFunc GetFunction(const String& _name, const ObjectPtr<Object>& _self) final { \
using SelfPtr = std::remove_cv_t<decltype(this)>;
#define TVM_MODULE_VTABLE_END() \
return PackedFunc(nullptr); \
}
#define TVM_MODULE_VTABLE_ENTRY(Name, MemFunc) \
if (_name == Name) { \
return PackedFunc([_self](TVMArgs args, TVMRetValue* rv) -> void { \
using Helper = ::tvm::runtime::detail::ModuleVTableEntryHelper<decltype(MemFunc)>; \
SelfPtr self = static_cast<SelfPtr>(_self.get()); \
CHECK_EQ(args.size(), Helper::LenArgs) \
<< "Function `" << self->type_key() << "::" << Name << "` requires " << Helper::LenArgs \
<< " arguments, but got " << args.size(); \
Helper::Call(rv, self, MemFunc, args, Helper::IndexSeq{}); \
}); \
}
#define TVM_MODULE_VTABLE_ENTRY_PACKED(Name, Func) \
if (_name == Name) { \
auto f = (Func); \
using FType = ::tvm::runtime::detail::function_signature<decltype(f)>::FType; \
return TypedPackedFunc<FType>(std::move(f)).packed(); \
}

/*!
* \brief Export typed function as a PackedFunc
* that can be loaded by LibraryModule.
Expand Down Expand Up @@ -1330,6 +1355,61 @@ inline void for_each(const F& f, Args&&... args) { // NOLINT(*)
for_each_dispatcher<sizeof...(Args) == 0, 0, F>::run(f, std::forward<Args>(args)...);
}

template <typename T>
struct ModuleVTableEntryHelper {};

template <typename T, typename R, typename... Args>
struct ModuleVTableEntryHelper<R (T::*)(Args...) const> {
using MemFnType = R (T::*)(Args...) const;
using IndexSeq = std::index_sequence_for<Args...>;
static constexpr const std::size_t LenArgs = sizeof...(Args);

template <std::size_t... Is>
static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args,
std::index_sequence<Is...>) {
*rv = (self->*f)(args[Is]...);
}
};

template <typename T, typename R, typename... Args>
struct ModuleVTableEntryHelper<R (T::*)(Args...)> {
using MemFnType = R (T::*)(Args...);
using IndexSeq = std::index_sequence_for<Args...>;
static constexpr const std::size_t LenArgs = sizeof...(Args);

template <std::size_t... Is>
static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args,
std::index_sequence<Is...>) {
*rv = (self->*f)(args[Is]...);
}
};

template <typename T, typename... Args>
struct ModuleVTableEntryHelper<void (T::*)(Args...) const> {
using MemFnType = void (T::*)(Args...) const;
using IndexSeq = std::index_sequence_for<Args...>;
static constexpr const std::size_t LenArgs = sizeof...(Args);

template <std::size_t... Is>
static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args,
std::index_sequence<Is...>) {
(self->*f)(args[Is]...);
}
};

template <typename T, typename... Args>
struct ModuleVTableEntryHelper<void (T::*)(Args...)> {
using MemFnType = void (T::*)(Args...);
using IndexSeq = std::index_sequence_for<Args...>;
static constexpr const std::size_t LenArgs = sizeof...(Args);

template <std::size_t... Is>
static TVM_ALWAYS_INLINE void Call(TVMRetValue* rv, T* self, MemFnType f, TVMArgs args,
std::index_sequence<Is...>) {
(self->*f)(args[Is]...);
}
};

namespace parameter_pack {

template <typename... EnumArgs>
Expand Down
39 changes: 23 additions & 16 deletions include/tvm/runtime/vm/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,28 @@ struct VMFunction;
*/
class TVM_DLL Executable : public ModuleNode {
public:
/*!
* \brief Get a PackedFunc from an executable module.
*
* \param name the name of the function.
* \param sptr_to_self The shared_ptr that points to this module node.
*
* \return PackedFunc or nullptr when it is not available.
*/
PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final;
TVM_MODULE_VTABLE_BEGIN("VMExecutable");
TVM_MODULE_VTABLE_ENTRY("get_lib", &Executable::GetLib);
TVM_MODULE_VTABLE_ENTRY("get_bytecode", &Executable::GetBytecode);
TVM_MODULE_VTABLE_ENTRY("get_constants", &Executable::GetConstants);
TVM_MODULE_VTABLE_ENTRY("get_virtual_devices", &Executable::GetVirtualDevices);
TVM_MODULE_VTABLE_ENTRY("get_primitives", &Executable::GetPrimitives);
TVM_MODULE_VTABLE_ENTRY("get_stats", &Executable::Stats);
TVM_MODULE_VTABLE_ENTRY("save", &Executable::Save);
TVM_MODULE_VTABLE_ENTRY("get_function_arity", &Executable::GetFunctionArity);
TVM_MODULE_VTABLE_ENTRY("get_function_param_name", &Executable::GetFunctionParameterName);
TVM_MODULE_VTABLE_ENTRY("vm_load_executable", &Executable::VMLoadExecutable);
TVM_MODULE_VTABLE_ENTRY("move_late_bound_consts", &Executable::MoveLateBoundConstantsToFile);
TVM_MODULE_VTABLE_ENTRY("get_late_bound_consts", &Executable::GetLateBoundConstants);
TVM_MODULE_VTABLE_ENTRY("load_late_bound_consts", &Executable::LoadLateBoundConstantsFromFile);
TVM_MODULE_VTABLE_ENTRY("load_late_bound_consts_from_map",
&Executable::LoadLateBoundConstantsFromMap);
TVM_MODULE_VTABLE_END();

/*! \brief Get the property of the runtime module .*/
int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; };

/*! \brief Creates a VM that loads `this` as the executable. */
Module VMLoadExecutable();
/*!
* \brief Write the Executable to the binary stream in serialized form.
*
Expand Down Expand Up @@ -123,17 +132,17 @@ class TVM_DLL Executable : public ModuleNode {
* Must be called before \p SaveToBinary and friends if late-bound constants are
* desired. Otherwise can be ignore.
*/
void MoveLateBoundConstantsToStream(dmlc::Stream* stream, size_t byte_limit);
void MoveLateBoundConstantsToStream(dmlc::Stream* stream, int64_t byte_limit);

/*!
* \brief As for \p MoveLateBoundConstantsToStream, but save to file at \p path.
*/
void MoveLateBoundConstantsToFile(const std::string& path, size_t byte_limit);
void MoveLateBoundConstantsToFile(const std::string& path, int64_t byte_limit);

/*!
* \brief Get a map of all constants with larger that byte_limit in size.
*/
Map<String, NDArray> GetLateBoundConstants(size_t byte_limit);
Map<String, NDArray> GetLateBoundConstants(int64_t byte_limit);

/*!
* \brief Restores the late-bound constants for the executable (if any) from given byte-stream.
Expand Down Expand Up @@ -255,12 +264,10 @@ class TVM_DLL Executable : public ModuleNode {
* \param index Parameter index.
* \return The parameter name.
*/
std::string GetFunctionParameterName(std::string func, uint32_t index) const;
std::string GetFunctionParameterName(std::string func, int index) const;

virtual ~Executable() {}

const char* type_key() const final { return "VMExecutable"; }

/*!
* \brief The (compile-time, virtual) devices corresponding to each device index.
* This vector contains a pair Device and its memory_scope.
Expand Down
81 changes: 11 additions & 70 deletions src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,83 +55,24 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr);
// Helper to deserialize a serialized vm instruction.
Instruction DeserializeInstruction(const VMInstructionSerializer& instr);

PackedFunc Executable::GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) {
if (name == "get_lib") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLib(); });
} else if (name == "get_bytecode") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetBytecode(); });
} else if (name == "get_constants") {
return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetConstants(); });
} else if (name == "get_virtual_devices") {
return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetVirtualDevices(); });
} else if (name == "get_primitives") {
return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetPrimitives(); });
} else if (name == "get_stats") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Stats(); });
} else if (name == "save") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Save(); });
} else if (name == "get_function_arity") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::string func_name = args[0];
*rv = this->GetFunctionArity(func_name);
});
} else if (name == "get_function_param_name") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::string func_name = args[0];
int index = args[1];
*rv = this->GetFunctionParameterName(func_name, index);
});
} else if (name == "vm_load_executable") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
auto vm = make_object<VirtualMachine>();
ICHECK(sptr_to_self.get() == this);
vm->LoadExecutable(GetObjectPtr<Executable>(this));
*rv = Module(vm);
});
} else if (name == "move_late_bound_consts") {
return PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.size(), 2);
std::string path = args[0];
uint64_t byte_limit = args[1];
MoveLateBoundConstantsToFile(path, static_cast<size_t>(byte_limit));
});
} else if (name == "get_late_bound_consts") {
return PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.size(), 1);
uint64_t byte_limit = args[0];
Map<String, NDArray> consts = GetLateBoundConstants(static_cast<size_t>(byte_limit));
*rv = consts;
});
} else if (name == "load_late_bound_consts") {
return PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.size(), 1);
std::string path = args[0];
LoadLateBoundConstantsFromFile(path);
});
} else if (name == "load_late_bound_consts_from_map") {
return PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.size(), 1);
Map<String, NDArray> map = args[0];
LoadLateBoundConstantsFromMap(map);
});
}
return nullptr;
}

const VMFunction& Executable::GetVMFunctionWithName(const std::string& func_name) const {
auto it = global_map.find(func_name);
ICHECK(it != global_map.end()) << "Cannot find function " << func_name << " in executable";
return functions[it->second];
}

Module Executable::VMLoadExecutable() {
auto vm = make_object<VirtualMachine>();
vm->LoadExecutable(GetObjectPtr<Executable>(this));
return Module(vm);
}

int Executable::GetFunctionArity(std::string func_name) const {
const auto& func = GetVMFunctionWithName(func_name);
return func.params.size();
}

std::string Executable::GetFunctionParameterName(std::string func_name, uint32_t index) const {
std::string Executable::GetFunctionParameterName(std::string func_name, int index) const {
const auto& func = GetVMFunctionWithName(func_name);
ICHECK_LT(index, func.params.size()) << "Invalid parameter index";
return func.params[index];
Expand Down Expand Up @@ -311,15 +252,15 @@ void Executable::SaveVirtualDevicesSection(dmlc::Stream* strm) {
strm->Write(host_device_index);
}

Map<String, NDArray> Executable::GetLateBoundConstants(size_t byte_limit) {
Map<String, NDArray> Executable::GetLateBoundConstants(int64_t byte_limit) {
ICHECK(late_bound_constant_names.empty());
late_bound_constant_names.reserve(constants.size());
Map<String, NDArray> map;
size_t total_late_bound_bytes = 0;
for (size_t const_index = 0; const_index < constants.size(); ++const_index) {
const auto ndarray = Downcast<NDArray>(constants[const_index]);
ICHECK(ndarray.defined()) << "Undefined constant at index " << const_index;
size_t num_bytes = runtime::GetDataSize(*ndarray.operator->());
int64_t num_bytes = runtime::GetDataSize(*ndarray.operator->());
if (num_bytes < byte_limit) {
// Leave as immediate.
late_bound_constant_names.emplace_back(nullptr);
Expand All @@ -337,12 +278,12 @@ Map<String, NDArray> Executable::GetLateBoundConstants(size_t byte_limit) {
return map;
}

void Executable::MoveLateBoundConstantsToStream(dmlc::Stream* stream, size_t byte_limit) {
void Executable::MoveLateBoundConstantsToStream(dmlc::Stream* stream, int64_t byte_limit) {
Map<String, NDArray> map = GetLateBoundConstants(byte_limit);
runtime::SaveParams(stream, map);
}

void Executable::MoveLateBoundConstantsToFile(const std::string& path, size_t byte_limit) {
void Executable::MoveLateBoundConstantsToFile(const std::string& path, int64_t byte_limit) {
tvm::runtime::SimpleBinaryFileStream stream(path, "wb");
MoveLateBoundConstantsToStream(&stream, byte_limit);
}
Expand Down