From 2f35e3876cefb3cf259281e5d1b574da6ec156a0 Mon Sep 17 00:00:00 2001 From: cyx666 <737363395@qq.com> Date: Sat, 22 Jan 2022 16:01:29 -0800 Subject: [PATCH 1/8] Add PackedFuncObj --- include/tvm/runtime/packed_func.h | 195 ++++++++++++++++++--------- include/tvm/runtime/registry.h | 7 +- src/ir/op.cc | 8 +- src/runtime/c_runtime_api.cc | 32 +++-- src/runtime/registry.cc | 12 +- src/runtime/rpc/rpc_local_session.cc | 9 +- src/target/generic_func.cc | 18 +-- web/emcc/tvmjs_support.cc | 14 +- 8 files changed, 202 insertions(+), 93 deletions(-) diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 9bfe379a3d77..12ec4b816329 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -57,6 +57,73 @@ class TVMMovableArgValueWithContext_; class TVMRetValue; class TVMArgsSetter; +/*! + * \brief Object container class that backs PackedFunc. + * \note Do not use this function directly, use PackedFunc. + */ +class PackedFuncObj : public Object { + public: + /*! + * \brief Call the function in packed format. + * \param args The arguments + * \param rv The return value. + */ + inline void CallPacked(TVMArgs args, TVMRetValue* rv) const; + + /*! \return Whether the packed function is nullptr */ + bool operator==(std::nullptr_t null) const { return f_call_ == nullptr; } + /*! \return Whether the packed function is not nullptr */ + bool operator!=(std::nullptr_t null) const { return f_call_ != nullptr; } + + static constexpr const char* _type_key = "PackedFuncObj"; + TVM_DECLARE_FINAL_OBJECT_INFO(PackedFuncObj, Object); + + protected: + /*! + * \brief Internal struct for extracting the callable method from callable type. + */ + template + struct Extractor { + /*! + * \brief extracting the callable method from callable type. + * \param obj The base packed function object class. + * \param args The arguments + * \param rv The return value. + */ + static void Call(const PackedFuncObj* obj, TVMArgs args, TVMRetValue* rv); + }; + + /*! \brief The internal callable function type. */ + using FCall = void(const PackedFuncObj*, TVMArgs, TVMRetValue*); + + /*! + * \brief Constructing a packed function object from a function pointer. + * \param f_call The function pointer used to call the packed function. + */ + explicit PackedFuncObj(FCall* f_call) : f_call_(f_call) {} + + /*! \brief Internal callable function pointer used to call the packed function. */ + FCall* f_call_; +}; + +/*! \brief Derived object class for constructing PackedFuncObj. */ +template +class PackedFuncSubObj : public PackedFuncObj { + using TStorage = typename std::remove_cv::type>::type; + + public: + /*! \brief The type of derived object class */ + using TSelf = PackedFuncSubObj; + /*! + * \brief Derived object class for constructing PackedFuncObj. + * \param callable The type-erased callable object. + */ + explicit PackedFuncSubObj(TCallable callable) + : PackedFuncObj(Extractor::Call), callable_(callable) {} + /*! \brief Type-erased filed for storing callable object*/ + mutable TStorage callable_; +}; + /*! * \brief Packed function is a type-erased function. * The arguments are passed by packed format. @@ -65,36 +132,22 @@ class TVMArgsSetter; * It is the unified function function type of TVM. * It corresponds to TVMFunctionHandle in C runtime API. */ -class PackedFunc { +class PackedFunc : public ObjectRef { public: - /*! - * \brief The internal std::function - * \param args The arguments to the function. - * \param rv The return value. - * - * \code - * // Example code on how to implemented FType - * void MyPackedFunc(TVMArgs args, TVMRetValue* rv) { - * // automatically convert arguments to desired type. - * int a0 = args[0]; - * float a1 = args[1]; - * ... - * // automatically assign values to rv - * std::string my_return_value = "x"; - * *rv = my_return_value; - * } - * \endcode - */ - using FType = std::function; - /*! \brief default constructor */ - PackedFunc() {} /*! \brief constructor from null */ - PackedFunc(std::nullptr_t null) {} // NOLINT(*) + PackedFunc(std::nullptr_t null): ObjectRef(nullptr) {} // NOLINT(*) /*! - * \brief constructing a packed function from a std::function. - * \param body the internal container of packed function. + * \brief constructing a packed function from a type-erased callable type. + * \param data the internal container of packed function. */ - explicit PackedFunc(FType body) : body_(body) {} + template >::value && + !std::is_base_of::value>> + explicit PackedFunc(TCallable data) { + using ObjType = PackedFuncSubObj; + data_ = make_object(std::forward(data)); + } /*! * \brief Call packed function by directly passing in unpacked format. * \param args Arguments to be passed. @@ -117,16 +170,12 @@ class PackedFunc { * \param rv The return value. */ inline void CallPacked(TVMArgs args, TVMRetValue* rv) const; - /*! \return the internal body function */ - inline FType body() const; /*! \return Whether the packed function is nullptr */ - bool operator==(std::nullptr_t null) const { return body_ == nullptr; } + bool operator==(std::nullptr_t null) const { return data_ == nullptr; } /*! \return Whether the packed function is not nullptr */ - bool operator!=(std::nullptr_t null) const { return body_ != nullptr; } + bool operator!=(std::nullptr_t null) const { return data_ != nullptr; } - private: - /*! \brief internal container of packed function */ - FType body_; + TVM_DEFINE_OBJECT_REF_METHODS(PackedFunc, ObjectRef, PackedFuncObj); }; /*! @@ -540,6 +589,13 @@ class TVMPODValue_ { TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle); return Module(ObjectPtr(static_cast(value_.v_handle))); } + operator PackedFunc() const { + if (type_code_ == kTVMNullptr) { + return PackedFunc(ObjectPtr(nullptr)); + } + TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); + return PackedFunc(ObjectPtr(static_cast(value_.v_handle))); + } operator Device() const { TVM_CHECK_TYPE_CODE(type_code_, kDLDevice); return value_.v_device; @@ -601,6 +657,7 @@ class TVMArgValue : public TVMPODValue_ { using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Device; using TVMPODValue_::operator Module; + using TVMPODValue_::operator PackedFunc; using TVMPODValue_::AsObjectRef; using TVMPODValue_::IsObjectRef; @@ -620,11 +677,6 @@ class TVMArgValue : public TVMPODValue_ { return AsObjectRef().operator std::string(); } } - operator PackedFunc() const { - if (type_code_ == kTVMNullptr) return PackedFunc(); - TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); - return *ptr(); - } template operator TypedPackedFunc() const { return TypedPackedFunc(operator PackedFunc()); @@ -661,9 +713,9 @@ class TVMMovableArgValue_ : public TVMPODValue_ { using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Device; using TVMPODValue_::operator Module; + using TVMPODValue_::operator PackedFunc; // reuse conversion rule from ArgValue. operator std::string() const { return AsArgValue().operator std::string(); } - operator PackedFunc() const { return AsArgValue().operator PackedFunc(); } template operator TypedPackedFunc() const { return TypedPackedFunc(operator PackedFunc()); @@ -756,6 +808,7 @@ class TVMRetValue : public TVMPODValue_ { using TVMPODValue_::operator Device; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Module; + using TVMPODValue_::operator PackedFunc; using TVMPODValue_::AsObjectRef; using TVMPODValue_::IsObjectRef; @@ -778,11 +831,6 @@ class TVMRetValue : public TVMPODValue_ { return value_.v_type; } operator DataType() const { return DataType(operator DLDataType()); } - operator PackedFunc() const { - if (type_code_ == kTVMNullptr) return PackedFunc(); - TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); - return *ptr(); - } template operator TypedPackedFunc() const { return TypedPackedFunc(operator PackedFunc()); @@ -860,11 +908,7 @@ class TVMRetValue : public TVMPODValue_ { return *this; } TVMRetValue& operator=(PackedFunc f) { - if (f == nullptr) { - this->SwitchToPOD(kTVMNullptr); - } else { - this->SwitchToClass(kTVMPackedFuncHandle, f); - } + this->SwitchToObject(kTVMPackedFuncHandle, std::move(f.data_)); return *this; } template @@ -941,7 +985,7 @@ class TVMRetValue : public TVMPODValue_ { break; } case kTVMPackedFuncHandle: { - SwitchToClass(kTVMPackedFuncHandle, other); + *this = other.operator PackedFunc(); break; } case kTVMModuleHandle: { @@ -1005,7 +1049,7 @@ class TVMRetValue : public TVMPODValue_ { delete ptr(); break; case kTVMPackedFuncHandle: - delete ptr(); + static_cast(value_.v_handle)->DecRef(); break; case kTVMNDArrayHandle: { NDArray::FFIDecRef(static_cast(value_.v_handle)); @@ -1148,9 +1192,19 @@ inline TVMArgValue TVMArgs::operator[](int i) const { inline int TVMArgs::size() const { return num_args; } -inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { body_(args, rv); } +template +void PackedFuncObj::Extractor::Call(const PackedFuncObj* obj, TVMArgs args, TVMRetValue* rv) { + (static_cast(obj))->callable_(args, rv); +} + +inline void PackedFuncObj::CallPacked(TVMArgs args, TVMRetValue* rv) const { + (*f_call_)(this, args, rv); +} + -inline PackedFunc::FType PackedFunc::body() const { return body_; } +inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { + (static_cast(data_.get()))->CallPacked(args, rv); +} // internal namespace inline const char* ArgTypeCode2Str(int type_code) { @@ -1312,15 +1366,6 @@ class TVMArgsSetter { values_[i].v_handle = const_cast(&value); type_codes_[i] = kTVMBytes; } - TVM_ALWAYS_INLINE void operator()(size_t i, const PackedFunc& value) const { - if (value != nullptr) { - values_[i].v_handle = const_cast(&value); - type_codes_[i] = kTVMPackedFuncHandle; - } else { - values_[i].v_handle = nullptr; - type_codes_[i] = kTVMNullptr; - } - } template TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc& value) const { operator()(i, value.packed()); @@ -1366,7 +1411,8 @@ inline TVMRetValue PackedFunc::operator()(Args&&... args) const { int type_codes[kArraySize]; detail::for_each(TVMArgsSetter(values, type_codes), std::forward(args)...); TVMRetValue rv; - body_(TVMArgs(values, type_codes, kNumArgs), &rv); + (static_cast(data_.get())) + ->CallPacked(TVMArgs(values, type_codes, kNumArgs), &rv); return rv; } @@ -1518,6 +1564,11 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { ptr->IsInstance())) { values_[i].v_handle = ptr; type_codes_[i] = kTVMModuleHandle; + } else if (std::is_base_of::value || + (std::is_base_of::value && + ptr->IsInstance())) { + values_[i].v_handle = ptr; + type_codes_[i] = kTVMPackedFuncHandle; } else if (std::is_rvalue_reference::value) { values_[i].v_handle = const_cast(&(value.data_.data_)); type_codes_[i] = kTVMObjectRValueRefArg; @@ -1543,6 +1594,10 @@ inline bool TVMPODValue_::IsObjectRef() const { return type_code_ == kTVMModuleHandle && static_cast(value_.v_handle)->IsInstance(); } + if (std::is_base_of::value) { + return type_code_ == kTVMPackedFuncHandle && + static_cast(value_.v_handle)->IsInstance(); + } // NOTE: we don't pass NDArray and runtime::Module as RValue ref. if (type_code_ == kTVMObjectRValueRefArg) { return ObjectTypeChecker::Check(*static_cast(value_.v_handle)); @@ -1551,6 +1606,8 @@ inline bool TVMPODValue_::IsObjectRef() const { type_code_ == kTVMNDArrayHandle) || (std::is_base_of::value && type_code_ == kTVMModuleHandle) || + (std::is_base_of::value && + type_code_ == kTVMPackedFuncHandle) || (type_code_ == kTVMObjectHandle && ObjectTypeChecker::Check(static_cast(value_.v_handle))); } @@ -1584,6 +1641,14 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); return TObjectRef(data); } + if (std::is_base_of::value) { + // Casting to a sub-class of PackedFunc + TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); + ObjectPtr data = GetObjectPtr(static_cast(value_.v_handle)); + CHECK(data->IsInstance()) + << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); + return TObjectRef(data); + } if (type_code_ == kTVMObjectHandle) { // normal object type check. Object* ptr = static_cast(value_.v_handle); @@ -1607,6 +1672,10 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { type_code_ == kTVMModuleHandle) { // Casting to a base class that Module can sub-class return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); + } else if (std::is_base_of::value && + type_code_ == kTVMPackedFuncHandle) { + // Casting to a base class that PackedFunc can sub-class + return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); } else { TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); return TObjectRef(ObjectPtr(nullptr)); diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h index 28fad3510064..316fbeb7e891 100644 --- a/include/tvm/runtime/registry.h +++ b/include/tvm/runtime/registry.h @@ -46,6 +46,7 @@ #include #include +#include #include #include @@ -108,7 +109,11 @@ class Registry { * \brief set the body of the function to be f * \param f The body of the function. */ - Registry& set_body(PackedFunc::FType f) { // NOLINT(*) + template >::value && + !std::is_base_of::value>> + Registry& set_body(TCallable f) { // NOLINT(*) return set_body(PackedFunc(f)); } /*! diff --git a/src/ir/op.cc b/src/ir/op.cc index e0bf5611c6a5..e80a10f84def 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -121,13 +121,13 @@ TVM_REGISTER_GLOBAL("ir.OpAddTypeRel") auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); if (value.type_code() == kTVMPackedFuncHandle) { // do an eager copy of the PackedFunc to avoid deleting function from frontend. - PackedFunc* fcopy = new PackedFunc(value.operator tvm::runtime::PackedFunc()); + PackedFunc fcopy = value; auto f = [=](const Array& args, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) -> bool { Array input_types(args.begin(), args.end() - 1); // call customized relation functions // *fcopy's signature: function (args: List[Type], attrs: Attrs) -> Type - Type ret_type = (*fcopy)(input_types, attrs); + Type ret_type = fcopy(input_types, attrs); // when defined ret_type, inference of output type is ok, do type assign // otherwise, inference failure happens if (ret_type.defined()) { @@ -185,9 +185,7 @@ TVM_REGISTER_GLOBAL("ir.RegisterOpAttr") if (value.type_code() == kTVMPackedFuncHandle) { // do an eager copy of the PackedFunc PackedFunc f = value; - // If we get a function from frontend, avoid deleting it. - auto* fcopy = new PackedFunc(f); - reg.set_attr(attr_key, *fcopy, plevel); + reg.set_attr(attr_key, f, plevel); } else { reg.set_attr(attr_key, value, plevel); } diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 2a71dad2bd2c..0cd1578b55b2 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -407,7 +407,12 @@ int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_impo API_BEGIN(); PackedFunc pf = ObjectInternal::GetModuleNode(mod)->GetFunction(func_name, query_imports != 0); if (pf != nullptr) { - *func = new PackedFunc(pf); + tvm::runtime::TVMRetValue ret; + ret = pf; + TVMValue val; + int type_code; + ret.MoveToCHost(&val, &type_code); + *func = val.v_handle; } else { *func = nullptr; } @@ -418,7 +423,7 @@ int TVMModFree(TVMModuleHandle mod) { return TVMObjectFree(mod); } int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFunctionHandle* func) { API_BEGIN(); - *func = (TVMFunctionHandle)(static_cast(mod_node)->GetFuncFromEnv(func_name)); + *func = (TVMFunctionHandle)(static_cast(mod_node)->GetFuncFromEnv(func_name))->get(); API_END(); } @@ -452,11 +457,7 @@ int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) { return 0; } -int TVMFuncFree(TVMFunctionHandle func) { - API_BEGIN(); - delete static_cast(func); - API_END(); -} +int TVMFuncFree(TVMFunctionHandle func) { return TVMObjectFree(func); } int TVMByteArrayFree(TVMByteArray* arr) { if (arr == &TVMAPIRuntimeStore::Get()->ret_bytes) { @@ -472,7 +473,8 @@ int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int API_BEGIN(); TVMRetValue rv; - (*static_cast(func)).CallPacked(TVMArgs(args, arg_type_codes, num_args), &rv); + (static_cast(func)) + ->CallPacked(TVMArgs(args, arg_type_codes, num_args), &rv); // handle return string. if (rv.type_code() == kTVMStr || rv.type_code() == kTVMDataType || rv.type_code() == kTVMBytes) { TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get(); @@ -508,24 +510,34 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, TVMPacked TVMFunctionHandle* out) { API_BEGIN(); if (fin == nullptr) { - *out = new PackedFunc([func, resource_handle](TVMArgs args, TVMRetValue* rv) { + tvm::runtime::TVMRetValue ret; + ret = PackedFunc([func, resource_handle](TVMArgs args, TVMRetValue* rv) { int ret = func(const_cast(args.values), const_cast(args.type_codes), args.num_args, rv, resource_handle); if (ret != 0) { throw tvm::Error(TVMGetLastError() + tvm::runtime::Backtrace()); } }); + TVMValue val; + int type_code; + ret.MoveToCHost(&val, &type_code); + *out = val.v_handle; } else { // wrap it in a shared_ptr, with fin as deleter. // so fin will be called when the lambda went out of scope. std::shared_ptr rpack(resource_handle, fin); - *out = new PackedFunc([func, rpack](TVMArgs args, TVMRetValue* rv) { + tvm::runtime::TVMRetValue ret; + ret = PackedFunc([func, rpack](TVMArgs args, TVMRetValue* rv) { int ret = func(const_cast(args.values), const_cast(args.type_codes), args.num_args, rv, rpack.get()); if (ret != 0) { throw tvm::Error(TVMGetLastError() + tvm::runtime::Backtrace()); } }); + TVMValue val; + int type_code; + ret.MoveToCHost(&val, &type_code); + *out = val.v_handle; } API_END(); } diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc index f15192662243..7b171f6e77c3 100644 --- a/src/runtime/registry.cc +++ b/src/runtime/registry.cc @@ -189,8 +189,11 @@ typedef dmlc::ThreadLocalStore TVMFuncThreadLocalStore; int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) { API_BEGIN(); + using tvm::runtime::GetRef; + using tvm::runtime::PackedFunc; + using tvm::runtime::PackedFuncObj; tvm::runtime::Registry::Register(name, override != 0) - .set_body(*static_cast(f)); + .set_body(GetRef(static_cast(f))); API_END(); } @@ -198,7 +201,12 @@ int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) { API_BEGIN(); const tvm::runtime::PackedFunc* fp = tvm::runtime::Registry::Get(name); if (fp != nullptr) { - *out = new tvm::runtime::PackedFunc(*fp); // NOLINT(*) + tvm::runtime::TVMRetValue ret; + ret = *fp; + TVMValue val; + int type_code; + ret.MoveToCHost(&val, &type_code); + *out = val.v_handle; } else { *out = nullptr; } diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc index 4b1c1f7fe998..d4aec5596f37 100644 --- a/src/runtime/rpc/rpc_local_session.cc +++ b/src/runtime/rpc/rpc_local_session.cc @@ -34,7 +34,12 @@ namespace runtime { RPCSession::PackedFuncHandle LocalSession::GetFunction(const std::string& name) { if (auto* fp = tvm::runtime::Registry::Get(name)) { // return raw handle because the remote need to explicitly manage it. - return new PackedFunc(*fp); + tvm::runtime::TVMRetValue ret; + ret = *fp; + TVMValue val; + int type_code; + ret.MoveToCHost(&val, &type_code); + return val.v_handle; } else { return nullptr; } @@ -81,7 +86,7 @@ void LocalSession::EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_retu void LocalSession::CallFunc(RPCSession::PackedFuncHandle func, const TVMValue* arg_values, const int* arg_type_codes, int num_args, const FEncodeReturn& encode_return) { - auto* pf = static_cast(func); + PackedFuncObj* pf = static_cast(func); TVMRetValue rv; pf->CallPacked(TVMArgs(arg_values, arg_type_codes, num_args), &rv); this->EncodeReturn(std::move(rv), encode_return); diff --git a/src/target/generic_func.cc b/src/target/generic_func.cc index a0065672139a..3135f6a9f240 100644 --- a/src/target/generic_func.cc +++ b/src/target/generic_func.cc @@ -31,6 +31,8 @@ #include #include +#include "../runtime/object_internal.h" + namespace tvm { TVM_REGISTER_NODE_TYPE(GenericFuncNode); @@ -143,26 +145,26 @@ TVM_REGISTER_GLOBAL("target.GenericFuncGetGlobal").set_body([](TVMArgs args, TVM TVM_REGISTER_GLOBAL("target.GenericFuncSetDefault").set_body([](TVMArgs args, TVMRetValue* ret) { GenericFunc generic_func = args[0]; - // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown - PackedFunc* func = new PackedFunc(args[1].operator PackedFunc()); + PackedFunc func = args[1]; bool allow_override = args[2]; - - generic_func.set_default(*func, allow_override); + // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown + runtime::ObjectInternal::ObjectRetain((TVMObjectHandle)(func.get())); + generic_func.set_default(func, allow_override); }); TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc").set_body([](TVMArgs args, TVMRetValue* ret) { GenericFunc generic_func = args[0]; - // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown - PackedFunc* func = new PackedFunc(args[1].operator PackedFunc()); + PackedFunc func = args[1]; Array tags = args[2]; bool allow_override = args[3]; - + // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown + runtime::ObjectInternal::ObjectRetain((TVMObjectHandle)(func.get())); std::vector tags_vector; for (auto& tag : tags) { tags_vector.push_back(tag); } - generic_func.register_func(tags_vector, *func, allow_override); + generic_func.register_func(tags_vector, func, allow_override); }); TVM_REGISTER_GLOBAL("target.GenericFuncCallFunc").set_body([](TVMArgs args, TVMRetValue* ret) { diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc index 3054bd0d7109..5a5c9361fc92 100644 --- a/web/emcc/tvmjs_support.cc +++ b/web/emcc/tvmjs_support.cc @@ -108,9 +108,19 @@ class AsyncLocalSession : public LocalSession { return get_time_eval_placeholder_.get(); } else if (auto* fp = tvm::runtime::Registry::Get(name)) { // return raw handle because the remote need to explicitly manage it. - return new PackedFunc(*fp); + tvm::runtime::TVMRetValue ret; + ret = *fp; + TVMValue val; + int type_code; + ret.MoveToCHost(&val, &type_code); + return val.v_handle; } else if (auto* fp = tvm::runtime::Registry::Get("__async." + name)) { - auto* rptr = new PackedFunc(*fp); + tvm::runtime::TVMRetValue ret; + ret = *fp; + TVMValue val; + int type_code; + ret.MoveToCHost(&val, &type_code); + auto* rptr = val.v_handle; async_func_set_.insert(rptr); return rptr; } else { From cad5cefd988b558245fba5d75712933701954ac0 Mon Sep 17 00:00:00 2001 From: cyx666 <737363395@qq.com> Date: Sun, 23 Jan 2022 14:07:25 -0800 Subject: [PATCH 2/8] Apply suggestions from code review Co-authored-by: Junru Shao --- include/tvm/runtime/packed_func.h | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 12ec4b816329..823a1dea6680 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -68,14 +68,15 @@ class PackedFuncObj : public Object { * \param args The arguments * \param rv The return value. */ - inline void CallPacked(TVMArgs args, TVMRetValue* rv) const; + inline TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue* rv) const; /*! \return Whether the packed function is nullptr */ bool operator==(std::nullptr_t null) const { return f_call_ == nullptr; } /*! \return Whether the packed function is not nullptr */ bool operator!=(std::nullptr_t null) const { return f_call_ != nullptr; } - static constexpr const char* _type_key = "PackedFuncObj"; + static constexpr const uint32_t _type_index = TypeIndex::kRuntimePackedFunc; + static constexpr const char* _type_key = "runtime.PackedFunc"; TVM_DECLARE_FINAL_OBJECT_INFO(PackedFuncObj, Object); protected: @@ -94,7 +95,7 @@ class PackedFuncObj : public Object { }; /*! \brief The internal callable function type. */ - using FCall = void(const PackedFuncObj*, TVMArgs, TVMRetValue*); + using FCallPacked = void(const PackedFuncObj*, TVMArgs, TVMRetValue*); /*! * \brief Constructing a packed function object from a function pointer. @@ -103,7 +104,7 @@ class PackedFuncObj : public Object { explicit PackedFuncObj(FCall* f_call) : f_call_(f_call) {} /*! \brief Internal callable function pointer used to call the packed function. */ - FCall* f_call_; + FCallPacked* f_call_packed_; }; /*! \brief Derived object class for constructing PackedFuncObj. */ @@ -137,7 +138,7 @@ class PackedFunc : public ObjectRef { /*! \brief constructor from null */ PackedFunc(std::nullptr_t null): ObjectRef(nullptr) {} // NOLINT(*) /*! - * \brief constructing a packed function from a type-erased callable type. + * \brief constructing a packed function from a callable type whose signature is consistent with `PackedFunc` * \param data the internal container of packed function. */ template ::Call(const PackedFuncObj* obj, (static_cast(obj))->callable_(args, rv); } -inline void PackedFuncObj::CallPacked(TVMArgs args, TVMRetValue* rv) const { +inline TVM_ALWAYS_INLINE void PackedFuncObj::CallPacked(TVMArgs args, TVMRetValue* rv) const { (*f_call_)(this, args, rv); } -inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { +inline TVM_ALWAYS_INLINE void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { (static_cast(data_.get()))->CallPacked(args, rv); } From 74f23d3ac089a63789089e2e3060c68576b506d3 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Sun, 23 Jan 2022 15:47:45 -0800 Subject: [PATCH 3/8] fix lint and update code review suggestions --- include/tvm/runtime/object.h | 2 ++ include/tvm/runtime/packed_func.h | 26 ++++++++++++++------------ src/runtime/c_runtime_api.cc | 2 +- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 0ed61177e65a..f44c8752d94e 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -70,6 +70,8 @@ struct TypeIndex { kRuntimeMap = 5, /*! \brief runtime::ShapeTuple. */ kRuntimeShapeTuple = 6, + /*! \brief runtime::PackedFunc. */ + kRuntimePackedFunc = 7, // static assignments that may subject to change. kRuntimeClosure, kRuntimeADT, diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 823a1dea6680..683795712600 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -68,12 +68,12 @@ class PackedFuncObj : public Object { * \param args The arguments * \param rv The return value. */ - inline TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue* rv) const; - + TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue* rv) const; + /*! \return Whether the packed function is nullptr */ - bool operator==(std::nullptr_t null) const { return f_call_ == nullptr; } + bool operator==(std::nullptr_t null) const { return f_call_packed_ == nullptr; } /*! \return Whether the packed function is not nullptr */ - bool operator!=(std::nullptr_t null) const { return f_call_ != nullptr; } + bool operator!=(std::nullptr_t null) const { return f_call_packed_ != nullptr; } static constexpr const uint32_t _type_index = TypeIndex::kRuntimePackedFunc; static constexpr const char* _type_key = "runtime.PackedFunc"; @@ -96,12 +96,12 @@ class PackedFuncObj : public Object { /*! \brief The internal callable function type. */ using FCallPacked = void(const PackedFuncObj*, TVMArgs, TVMRetValue*); - + /*! * \brief Constructing a packed function object from a function pointer. * \param f_call The function pointer used to call the packed function. */ - explicit PackedFuncObj(FCall* f_call) : f_call_(f_call) {} + explicit PackedFuncObj(FCallPacked* f_call_pack) : f_call_packed_(f_call_pack) {} /*! \brief Internal callable function pointer used to call the packed function. */ FCallPacked* f_call_packed_; @@ -170,7 +170,7 @@ class PackedFunc : public ObjectRef { * \param args The arguments * \param rv The return value. */ - inline TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue* rv) const; + TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue* rv) const; /*! \return Whether the packed function is nullptr */ bool operator==(std::nullptr_t null) const { return data_ == nullptr; } /*! \return Whether the packed function is not nullptr */ @@ -809,7 +809,7 @@ class TVMRetValue : public TVMPODValue_ { using TVMPODValue_::operator Device; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Module; - using TVMPODValue_::operator PackedFunc; + using TVMPODValue_::operator PackedFunc; using TVMPODValue_::AsObjectRef; using TVMPODValue_::IsObjectRef; @@ -1194,16 +1194,18 @@ inline TVMArgValue TVMArgs::operator[](int i) const { inline int TVMArgs::size() const { return num_args; } template -void PackedFuncObj::Extractor::Call(const PackedFuncObj* obj, TVMArgs args, TVMRetValue* rv) { +void PackedFuncObj::Extractor::Call(const PackedFuncObj* obj, + TVMArgs args, + TVMRetValue* rv) { (static_cast(obj))->callable_(args, rv); } -inline TVM_ALWAYS_INLINE void PackedFuncObj::CallPacked(TVMArgs args, TVMRetValue* rv) const { - (*f_call_)(this, args, rv); +TVM_ALWAYS_INLINE void PackedFuncObj::CallPacked(TVMArgs args, TVMRetValue* rv) const { + (*f_call_packed_)(this, args, rv); } -inline TVM_ALWAYS_INLINE void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { +TVM_ALWAYS_INLINE void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { (static_cast(data_.get()))->CallPacked(args, rv); } diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 0cd1578b55b2..2c4e476ffda3 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -474,7 +474,7 @@ int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int TVMRetValue rv; (static_cast(func)) - ->CallPacked(TVMArgs(args, arg_type_codes, num_args), &rv); + ->CallPacked(TVMArgs(args, arg_type_codes, num_args), &rv); // handle return string. if (rv.type_code() == kTVMStr || rv.type_code() == kTVMDataType || rv.type_code() == kTVMBytes) { TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get(); From b926586b69cb7b4866bb71f83e5f1bdc34575c90 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Sun, 23 Jan 2022 16:31:23 -0800 Subject: [PATCH 4/8] lint fix --- include/tvm/runtime/packed_func.h | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 683795712600..7bfc8b9e1cc2 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -81,12 +81,12 @@ class PackedFuncObj : public Object { protected: /*! - * \brief Internal struct for extracting the callable method from callable type. - */ + * \brief Internal struct for extracting the callable method from callable type. + */ template struct Extractor { - /*! - * \brief extracting the callable method from callable type. + /*! + * \brief Extracting the callable method from callable type. * \param obj The base packed function object class. * \param args The arguments * \param rv The return value. @@ -115,8 +115,8 @@ class PackedFuncSubObj : public PackedFuncObj { public: /*! \brief The type of derived object class */ using TSelf = PackedFuncSubObj; - /*! - * \brief Derived object class for constructing PackedFuncObj. + /*! + * \brief Derived object class for constructing PackedFuncObj. * \param callable The type-erased callable object. */ explicit PackedFuncSubObj(TCallable callable) @@ -135,10 +135,11 @@ class PackedFuncSubObj : public PackedFuncObj { */ class PackedFunc : public ObjectRef { public: - /*! \brief constructor from null */ - PackedFunc(std::nullptr_t null): ObjectRef(nullptr) {} // NOLINT(*) + /*! \brief Constructor from null */ + PackedFunc(std::nullptr_t null) : ObjectRef(nullptr) {} // NOLINT(*) /*! - * \brief constructing a packed function from a callable type whose signature is consistent with `PackedFunc` + * \brief Constructing a packed function from a callable type + * whose signature is consistent with `PackedFunc` * \param data the internal container of packed function. */ template -void PackedFuncObj::Extractor::Call(const PackedFuncObj* obj, - TVMArgs args, +void PackedFuncObj::Extractor::Call(const PackedFuncObj* obj, TVMArgs args, TVMRetValue* rv) { (static_cast(obj))->callable_(args, rv); } @@ -1204,7 +1204,6 @@ TVM_ALWAYS_INLINE void PackedFuncObj::CallPacked(TVMArgs args, TVMRetValue* rv) (*f_call_packed_)(this, args, rv); } - TVM_ALWAYS_INLINE void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { (static_cast(data_.get()))->CallPacked(args, rv); } From 56c94ee3a777585b38857c6c4ea3e08719aa4e6e Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Sun, 23 Jan 2022 16:40:54 -0800 Subject: [PATCH 5/8] doc fix --- include/tvm/runtime/packed_func.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 7bfc8b9e1cc2..83727f902faf 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -99,7 +99,7 @@ class PackedFuncObj : public Object { /*! * \brief Constructing a packed function object from a function pointer. - * \param f_call The function pointer used to call the packed function. + * \param f_call_pack The function pointer used to call the packed function. */ explicit PackedFuncObj(FCallPacked* f_call_pack) : f_call_packed_(f_call_pack) {} From 4fd56aa6649934814d1bec9cfcfa985368913644 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Mon, 31 Jan 2022 16:03:45 -0800 Subject: [PATCH 6/8] v_handle initialization for kTVMNullptr --- include/tvm/runtime/packed_func.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 83727f902faf..8021dfd63fd8 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -902,6 +902,7 @@ class TVMRetValue : public TVMPODValue_ { ObjectRef::FFIClearAfterMove(&other); } else { SwitchToPOD(kTVMNullptr); + value_.v_handle = nullptr; } return *this; } @@ -1041,6 +1042,7 @@ class TVMRetValue : public TVMPODValue_ { other.data_ = nullptr; } else { SwitchToPOD(kTVMNullptr); + value_.v_handle = nullptr; } } void Clear() { @@ -1580,6 +1582,7 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { } } else { type_codes_[i] = kTVMNullptr; + values_[i].v_handle = nullptr; } } @@ -1702,6 +1705,7 @@ inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { SwitchToObject(kTVMObjectHandle, std::move(other.data_)); } else { SwitchToPOD(kTVMNullptr); + value_.v_handle = nullptr; } return *this; } From 610b058eb2004a98241038c0e36cb31f76bfd5ec Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Mon, 31 Jan 2022 21:57:09 -0800 Subject: [PATCH 7/8] fix bug in TVMFunctionHandle --- apps/extension/src/tvm_ext.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/extension/src/tvm_ext.cc b/apps/extension/src/tvm_ext.cc index be431bab68d1..4150d55ba3e6 100644 --- a/apps/extension/src/tvm_ext.cc +++ b/apps/extension/src/tvm_ext.cc @@ -169,7 +169,7 @@ extern "C" float TVMTestAddOne(float y) { return y + 1; } // This way can be helpful when we want to use a header only // minimum version of TVM Runtime. extern "C" int TVMExtDeclare(TVMFunctionHandle pregister) { - const PackedFunc& fregister = *static_cast(pregister); + const PackedFunc& fregister = GetRef(static_cast(pregister)); auto mul = [](TVMArgs args, TVMRetValue* rv) { int x = args[0]; int y = args[1]; From 99b0bbf7f31df8463aea2d7fbc09ece72c8ba828 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Wed, 2 Feb 2022 21:55:19 -0800 Subject: [PATCH 8/8] Apply suggestions from code review --- include/tvm/runtime/packed_func.h | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 8021dfd63fd8..a3bb569045eb 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -70,11 +70,6 @@ class PackedFuncObj : public Object { */ TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue* rv) const; - /*! \return Whether the packed function is nullptr */ - bool operator==(std::nullptr_t null) const { return f_call_packed_ == nullptr; } - /*! \return Whether the packed function is not nullptr */ - bool operator!=(std::nullptr_t null) const { return f_call_packed_ != nullptr; } - static constexpr const uint32_t _type_index = TypeIndex::kRuntimePackedFunc; static constexpr const char* _type_key = "runtime.PackedFunc"; TVM_DECLARE_FINAL_OBJECT_INFO(PackedFuncObj, Object); @@ -103,6 +98,9 @@ class PackedFuncObj : public Object { */ explicit PackedFuncObj(FCallPacked* f_call_pack) : f_call_packed_(f_call_pack) {} + /*! \brief Delete the default constructor explicitly. */ + PackedFuncObj() = delete; + /*! \brief Internal callable function pointer used to call the packed function. */ FCallPacked* f_call_packed_; };