diff --git a/apps/extension/src/tvm_ext.cc b/apps/extension/src/tvm_ext.cc index be431bab68d1..4150d55ba3e6 100644 --- a/apps/extension/src/tvm_ext.cc +++ b/apps/extension/src/tvm_ext.cc @@ -169,7 +169,7 @@ extern "C" float TVMTestAddOne(float y) { return y + 1; } // This way can be helpful when we want to use a header only // minimum version of TVM Runtime. extern "C" int TVMExtDeclare(TVMFunctionHandle pregister) { - const PackedFunc& fregister = *static_cast(pregister); + const PackedFunc& fregister = GetRef(static_cast(pregister)); auto mul = [](TVMArgs args, TVMRetValue* rv) { int x = args[0]; int y = args[1]; diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 0ed61177e65a..f44c8752d94e 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -70,6 +70,8 @@ struct TypeIndex { kRuntimeMap = 5, /*! \brief runtime::ShapeTuple. */ kRuntimeShapeTuple = 6, + /*! \brief runtime::PackedFunc. */ + kRuntimePackedFunc = 7, // static assignments that may subject to change. kRuntimeClosure, kRuntimeADT, diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 9bfe379a3d77..a3bb569045eb 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -57,6 +57,72 @@ class TVMMovableArgValueWithContext_; class TVMRetValue; class TVMArgsSetter; +/*! + * \brief Object container class that backs PackedFunc. + * \note Do not use this function directly, use PackedFunc. + */ +class PackedFuncObj : public Object { + public: + /*! + * \brief Call the function in packed format. + * \param args The arguments + * \param rv The return value. + */ + TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue* rv) const; + + static constexpr const uint32_t _type_index = TypeIndex::kRuntimePackedFunc; + static constexpr const char* _type_key = "runtime.PackedFunc"; + TVM_DECLARE_FINAL_OBJECT_INFO(PackedFuncObj, Object); + + protected: + /*! + * \brief Internal struct for extracting the callable method from callable type. + */ + template + struct Extractor { + /*! + * \brief Extracting the callable method from callable type. + * \param obj The base packed function object class. + * \param args The arguments + * \param rv The return value. + */ + static void Call(const PackedFuncObj* obj, TVMArgs args, TVMRetValue* rv); + }; + + /*! \brief The internal callable function type. */ + using FCallPacked = void(const PackedFuncObj*, TVMArgs, TVMRetValue*); + + /*! + * \brief Constructing a packed function object from a function pointer. + * \param f_call_pack The function pointer used to call the packed function. + */ + explicit PackedFuncObj(FCallPacked* f_call_pack) : f_call_packed_(f_call_pack) {} + + /*! \brief Delete the default constructor explicitly. */ + PackedFuncObj() = delete; + + /*! \brief Internal callable function pointer used to call the packed function. */ + FCallPacked* f_call_packed_; +}; + +/*! \brief Derived object class for constructing PackedFuncObj. */ +template +class PackedFuncSubObj : public PackedFuncObj { + using TStorage = typename std::remove_cv::type>::type; + + public: + /*! \brief The type of derived object class */ + using TSelf = PackedFuncSubObj; + /*! + * \brief Derived object class for constructing PackedFuncObj. + * \param callable The type-erased callable object. + */ + explicit PackedFuncSubObj(TCallable callable) + : PackedFuncObj(Extractor::Call), callable_(callable) {} + /*! \brief Type-erased filed for storing callable object*/ + mutable TStorage callable_; +}; + /*! * \brief Packed function is a type-erased function. * The arguments are passed by packed format. @@ -65,36 +131,23 @@ class TVMArgsSetter; * It is the unified function function type of TVM. * It corresponds to TVMFunctionHandle in C runtime API. */ -class PackedFunc { +class PackedFunc : public ObjectRef { public: + /*! \brief Constructor from null */ + PackedFunc(std::nullptr_t null) : ObjectRef(nullptr) {} // NOLINT(*) /*! - * \brief The internal std::function - * \param args The arguments to the function. - * \param rv The return value. - * - * \code - * // Example code on how to implemented FType - * void MyPackedFunc(TVMArgs args, TVMRetValue* rv) { - * // automatically convert arguments to desired type. - * int a0 = args[0]; - * float a1 = args[1]; - * ... - * // automatically assign values to rv - * std::string my_return_value = "x"; - * *rv = my_return_value; - * } - * \endcode - */ - using FType = std::function; - /*! \brief default constructor */ - PackedFunc() {} - /*! \brief constructor from null */ - PackedFunc(std::nullptr_t null) {} // NOLINT(*) - /*! - * \brief constructing a packed function from a std::function. - * \param body the internal container of packed function. + * \brief Constructing a packed function from a callable type + * whose signature is consistent with `PackedFunc` + * \param data the internal container of packed function. */ - explicit PackedFunc(FType body) : body_(body) {} + template >::value && + !std::is_base_of::value>> + explicit PackedFunc(TCallable data) { + using ObjType = PackedFuncSubObj; + data_ = make_object(std::forward(data)); + } /*! * \brief Call packed function by directly passing in unpacked format. * \param args Arguments to be passed. @@ -116,17 +169,13 @@ class PackedFunc { * \param args The arguments * \param rv The return value. */ - inline void CallPacked(TVMArgs args, TVMRetValue* rv) const; - /*! \return the internal body function */ - inline FType body() const; + TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue* rv) const; /*! \return Whether the packed function is nullptr */ - bool operator==(std::nullptr_t null) const { return body_ == nullptr; } + bool operator==(std::nullptr_t null) const { return data_ == nullptr; } /*! \return Whether the packed function is not nullptr */ - bool operator!=(std::nullptr_t null) const { return body_ != nullptr; } + bool operator!=(std::nullptr_t null) const { return data_ != nullptr; } - private: - /*! \brief internal container of packed function */ - FType body_; + TVM_DEFINE_OBJECT_REF_METHODS(PackedFunc, ObjectRef, PackedFuncObj); }; /*! @@ -540,6 +589,13 @@ class TVMPODValue_ { TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle); return Module(ObjectPtr(static_cast(value_.v_handle))); } + operator PackedFunc() const { + if (type_code_ == kTVMNullptr) { + return PackedFunc(ObjectPtr(nullptr)); + } + TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); + return PackedFunc(ObjectPtr(static_cast(value_.v_handle))); + } operator Device() const { TVM_CHECK_TYPE_CODE(type_code_, kDLDevice); return value_.v_device; @@ -601,6 +657,7 @@ class TVMArgValue : public TVMPODValue_ { using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Device; using TVMPODValue_::operator Module; + using TVMPODValue_::operator PackedFunc; using TVMPODValue_::AsObjectRef; using TVMPODValue_::IsObjectRef; @@ -620,11 +677,6 @@ class TVMArgValue : public TVMPODValue_ { return AsObjectRef().operator std::string(); } } - operator PackedFunc() const { - if (type_code_ == kTVMNullptr) return PackedFunc(); - TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); - return *ptr(); - } template operator TypedPackedFunc() const { return TypedPackedFunc(operator PackedFunc()); @@ -661,9 +713,9 @@ class TVMMovableArgValue_ : public TVMPODValue_ { using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Device; using TVMPODValue_::operator Module; + using TVMPODValue_::operator PackedFunc; // reuse conversion rule from ArgValue. operator std::string() const { return AsArgValue().operator std::string(); } - operator PackedFunc() const { return AsArgValue().operator PackedFunc(); } template operator TypedPackedFunc() const { return TypedPackedFunc(operator PackedFunc()); @@ -756,6 +808,7 @@ class TVMRetValue : public TVMPODValue_ { using TVMPODValue_::operator Device; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Module; + using TVMPODValue_::operator PackedFunc; using TVMPODValue_::AsObjectRef; using TVMPODValue_::IsObjectRef; @@ -778,11 +831,6 @@ class TVMRetValue : public TVMPODValue_ { return value_.v_type; } operator DataType() const { return DataType(operator DLDataType()); } - operator PackedFunc() const { - if (type_code_ == kTVMNullptr) return PackedFunc(); - TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); - return *ptr(); - } template operator TypedPackedFunc() const { return TypedPackedFunc(operator PackedFunc()); @@ -852,6 +900,7 @@ class TVMRetValue : public TVMPODValue_ { ObjectRef::FFIClearAfterMove(&other); } else { SwitchToPOD(kTVMNullptr); + value_.v_handle = nullptr; } return *this; } @@ -860,11 +909,7 @@ class TVMRetValue : public TVMPODValue_ { return *this; } TVMRetValue& operator=(PackedFunc f) { - if (f == nullptr) { - this->SwitchToPOD(kTVMNullptr); - } else { - this->SwitchToClass(kTVMPackedFuncHandle, f); - } + this->SwitchToObject(kTVMPackedFuncHandle, std::move(f.data_)); return *this; } template @@ -941,7 +986,7 @@ class TVMRetValue : public TVMPODValue_ { break; } case kTVMPackedFuncHandle: { - SwitchToClass(kTVMPackedFuncHandle, other); + *this = other.operator PackedFunc(); break; } case kTVMModuleHandle: { @@ -995,6 +1040,7 @@ class TVMRetValue : public TVMPODValue_ { other.data_ = nullptr; } else { SwitchToPOD(kTVMNullptr); + value_.v_handle = nullptr; } } void Clear() { @@ -1005,7 +1051,7 @@ class TVMRetValue : public TVMPODValue_ { delete ptr(); break; case kTVMPackedFuncHandle: - delete ptr(); + static_cast(value_.v_handle)->DecRef(); break; case kTVMNDArrayHandle: { NDArray::FFIDecRef(static_cast(value_.v_handle)); @@ -1148,9 +1194,19 @@ inline TVMArgValue TVMArgs::operator[](int i) const { inline int TVMArgs::size() const { return num_args; } -inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { body_(args, rv); } +template +void PackedFuncObj::Extractor::Call(const PackedFuncObj* obj, TVMArgs args, + TVMRetValue* rv) { + (static_cast(obj))->callable_(args, rv); +} -inline PackedFunc::FType PackedFunc::body() const { return body_; } +TVM_ALWAYS_INLINE void PackedFuncObj::CallPacked(TVMArgs args, TVMRetValue* rv) const { + (*f_call_packed_)(this, args, rv); +} + +TVM_ALWAYS_INLINE void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { + (static_cast(data_.get()))->CallPacked(args, rv); +} // internal namespace inline const char* ArgTypeCode2Str(int type_code) { @@ -1312,15 +1368,6 @@ class TVMArgsSetter { values_[i].v_handle = const_cast(&value); type_codes_[i] = kTVMBytes; } - TVM_ALWAYS_INLINE void operator()(size_t i, const PackedFunc& value) const { - if (value != nullptr) { - values_[i].v_handle = const_cast(&value); - type_codes_[i] = kTVMPackedFuncHandle; - } else { - values_[i].v_handle = nullptr; - type_codes_[i] = kTVMNullptr; - } - } template TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc& value) const { operator()(i, value.packed()); @@ -1366,7 +1413,8 @@ inline TVMRetValue PackedFunc::operator()(Args&&... args) const { int type_codes[kArraySize]; detail::for_each(TVMArgsSetter(values, type_codes), std::forward(args)...); TVMRetValue rv; - body_(TVMArgs(values, type_codes, kNumArgs), &rv); + (static_cast(data_.get())) + ->CallPacked(TVMArgs(values, type_codes, kNumArgs), &rv); return rv; } @@ -1518,6 +1566,11 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { ptr->IsInstance())) { values_[i].v_handle = ptr; type_codes_[i] = kTVMModuleHandle; + } else if (std::is_base_of::value || + (std::is_base_of::value && + ptr->IsInstance())) { + values_[i].v_handle = ptr; + type_codes_[i] = kTVMPackedFuncHandle; } else if (std::is_rvalue_reference::value) { values_[i].v_handle = const_cast(&(value.data_.data_)); type_codes_[i] = kTVMObjectRValueRefArg; @@ -1527,6 +1580,7 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { } } else { type_codes_[i] = kTVMNullptr; + values_[i].v_handle = nullptr; } } @@ -1543,6 +1597,10 @@ inline bool TVMPODValue_::IsObjectRef() const { return type_code_ == kTVMModuleHandle && static_cast(value_.v_handle)->IsInstance(); } + if (std::is_base_of::value) { + return type_code_ == kTVMPackedFuncHandle && + static_cast(value_.v_handle)->IsInstance(); + } // NOTE: we don't pass NDArray and runtime::Module as RValue ref. if (type_code_ == kTVMObjectRValueRefArg) { return ObjectTypeChecker::Check(*static_cast(value_.v_handle)); @@ -1551,6 +1609,8 @@ inline bool TVMPODValue_::IsObjectRef() const { type_code_ == kTVMNDArrayHandle) || (std::is_base_of::value && type_code_ == kTVMModuleHandle) || + (std::is_base_of::value && + type_code_ == kTVMPackedFuncHandle) || (type_code_ == kTVMObjectHandle && ObjectTypeChecker::Check(static_cast(value_.v_handle))); } @@ -1584,6 +1644,14 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); return TObjectRef(data); } + if (std::is_base_of::value) { + // Casting to a sub-class of PackedFunc + TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); + ObjectPtr data = GetObjectPtr(static_cast(value_.v_handle)); + CHECK(data->IsInstance()) + << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); + return TObjectRef(data); + } if (type_code_ == kTVMObjectHandle) { // normal object type check. Object* ptr = static_cast(value_.v_handle); @@ -1607,6 +1675,10 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { type_code_ == kTVMModuleHandle) { // Casting to a base class that Module can sub-class return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); + } else if (std::is_base_of::value && + type_code_ == kTVMPackedFuncHandle) { + // Casting to a base class that PackedFunc can sub-class + return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); } else { TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); return TObjectRef(ObjectPtr(nullptr)); @@ -1631,6 +1703,7 @@ inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { SwitchToObject(kTVMObjectHandle, std::move(other.data_)); } else { SwitchToPOD(kTVMNullptr); + value_.v_handle = nullptr; } return *this; } diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h index 28fad3510064..316fbeb7e891 100644 --- a/include/tvm/runtime/registry.h +++ b/include/tvm/runtime/registry.h @@ -46,6 +46,7 @@ #include #include +#include #include #include @@ -108,7 +109,11 @@ class Registry { * \brief set the body of the function to be f * \param f The body of the function. */ - Registry& set_body(PackedFunc::FType f) { // NOLINT(*) + template >::value && + !std::is_base_of::value>> + Registry& set_body(TCallable f) { // NOLINT(*) return set_body(PackedFunc(f)); } /*! diff --git a/src/ir/op.cc b/src/ir/op.cc index e0bf5611c6a5..e80a10f84def 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -121,13 +121,13 @@ TVM_REGISTER_GLOBAL("ir.OpAddTypeRel") auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); if (value.type_code() == kTVMPackedFuncHandle) { // do an eager copy of the PackedFunc to avoid deleting function from frontend. - PackedFunc* fcopy = new PackedFunc(value.operator tvm::runtime::PackedFunc()); + PackedFunc fcopy = value; auto f = [=](const Array& args, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) -> bool { Array input_types(args.begin(), args.end() - 1); // call customized relation functions // *fcopy's signature: function (args: List[Type], attrs: Attrs) -> Type - Type ret_type = (*fcopy)(input_types, attrs); + Type ret_type = fcopy(input_types, attrs); // when defined ret_type, inference of output type is ok, do type assign // otherwise, inference failure happens if (ret_type.defined()) { @@ -185,9 +185,7 @@ TVM_REGISTER_GLOBAL("ir.RegisterOpAttr") if (value.type_code() == kTVMPackedFuncHandle) { // do an eager copy of the PackedFunc PackedFunc f = value; - // If we get a function from frontend, avoid deleting it. - auto* fcopy = new PackedFunc(f); - reg.set_attr(attr_key, *fcopy, plevel); + reg.set_attr(attr_key, f, plevel); } else { reg.set_attr(attr_key, value, plevel); } diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 2a71dad2bd2c..2c4e476ffda3 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -407,7 +407,12 @@ int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_impo API_BEGIN(); PackedFunc pf = ObjectInternal::GetModuleNode(mod)->GetFunction(func_name, query_imports != 0); if (pf != nullptr) { - *func = new PackedFunc(pf); + tvm::runtime::TVMRetValue ret; + ret = pf; + TVMValue val; + int type_code; + ret.MoveToCHost(&val, &type_code); + *func = val.v_handle; } else { *func = nullptr; } @@ -418,7 +423,7 @@ int TVMModFree(TVMModuleHandle mod) { return TVMObjectFree(mod); } int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFunctionHandle* func) { API_BEGIN(); - *func = (TVMFunctionHandle)(static_cast(mod_node)->GetFuncFromEnv(func_name)); + *func = (TVMFunctionHandle)(static_cast(mod_node)->GetFuncFromEnv(func_name))->get(); API_END(); } @@ -452,11 +457,7 @@ int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) { return 0; } -int TVMFuncFree(TVMFunctionHandle func) { - API_BEGIN(); - delete static_cast(func); - API_END(); -} +int TVMFuncFree(TVMFunctionHandle func) { return TVMObjectFree(func); } int TVMByteArrayFree(TVMByteArray* arr) { if (arr == &TVMAPIRuntimeStore::Get()->ret_bytes) { @@ -472,7 +473,8 @@ int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int API_BEGIN(); TVMRetValue rv; - (*static_cast(func)).CallPacked(TVMArgs(args, arg_type_codes, num_args), &rv); + (static_cast(func)) + ->CallPacked(TVMArgs(args, arg_type_codes, num_args), &rv); // handle return string. if (rv.type_code() == kTVMStr || rv.type_code() == kTVMDataType || rv.type_code() == kTVMBytes) { TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get(); @@ -508,24 +510,34 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, TVMPacked TVMFunctionHandle* out) { API_BEGIN(); if (fin == nullptr) { - *out = new PackedFunc([func, resource_handle](TVMArgs args, TVMRetValue* rv) { + tvm::runtime::TVMRetValue ret; + ret = PackedFunc([func, resource_handle](TVMArgs args, TVMRetValue* rv) { int ret = func(const_cast(args.values), const_cast(args.type_codes), args.num_args, rv, resource_handle); if (ret != 0) { throw tvm::Error(TVMGetLastError() + tvm::runtime::Backtrace()); } }); + TVMValue val; + int type_code; + ret.MoveToCHost(&val, &type_code); + *out = val.v_handle; } else { // wrap it in a shared_ptr, with fin as deleter. // so fin will be called when the lambda went out of scope. std::shared_ptr rpack(resource_handle, fin); - *out = new PackedFunc([func, rpack](TVMArgs args, TVMRetValue* rv) { + tvm::runtime::TVMRetValue ret; + ret = PackedFunc([func, rpack](TVMArgs args, TVMRetValue* rv) { int ret = func(const_cast(args.values), const_cast(args.type_codes), args.num_args, rv, rpack.get()); if (ret != 0) { throw tvm::Error(TVMGetLastError() + tvm::runtime::Backtrace()); } }); + TVMValue val; + int type_code; + ret.MoveToCHost(&val, &type_code); + *out = val.v_handle; } API_END(); } diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc index f15192662243..7b171f6e77c3 100644 --- a/src/runtime/registry.cc +++ b/src/runtime/registry.cc @@ -189,8 +189,11 @@ typedef dmlc::ThreadLocalStore TVMFuncThreadLocalStore; int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) { API_BEGIN(); + using tvm::runtime::GetRef; + using tvm::runtime::PackedFunc; + using tvm::runtime::PackedFuncObj; tvm::runtime::Registry::Register(name, override != 0) - .set_body(*static_cast(f)); + .set_body(GetRef(static_cast(f))); API_END(); } @@ -198,7 +201,12 @@ int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) { API_BEGIN(); const tvm::runtime::PackedFunc* fp = tvm::runtime::Registry::Get(name); if (fp != nullptr) { - *out = new tvm::runtime::PackedFunc(*fp); // NOLINT(*) + tvm::runtime::TVMRetValue ret; + ret = *fp; + TVMValue val; + int type_code; + ret.MoveToCHost(&val, &type_code); + *out = val.v_handle; } else { *out = nullptr; } diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc index 4b1c1f7fe998..d4aec5596f37 100644 --- a/src/runtime/rpc/rpc_local_session.cc +++ b/src/runtime/rpc/rpc_local_session.cc @@ -34,7 +34,12 @@ namespace runtime { RPCSession::PackedFuncHandle LocalSession::GetFunction(const std::string& name) { if (auto* fp = tvm::runtime::Registry::Get(name)) { // return raw handle because the remote need to explicitly manage it. - return new PackedFunc(*fp); + tvm::runtime::TVMRetValue ret; + ret = *fp; + TVMValue val; + int type_code; + ret.MoveToCHost(&val, &type_code); + return val.v_handle; } else { return nullptr; } @@ -81,7 +86,7 @@ void LocalSession::EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_retu void LocalSession::CallFunc(RPCSession::PackedFuncHandle func, const TVMValue* arg_values, const int* arg_type_codes, int num_args, const FEncodeReturn& encode_return) { - auto* pf = static_cast(func); + PackedFuncObj* pf = static_cast(func); TVMRetValue rv; pf->CallPacked(TVMArgs(arg_values, arg_type_codes, num_args), &rv); this->EncodeReturn(std::move(rv), encode_return); diff --git a/src/target/generic_func.cc b/src/target/generic_func.cc index a0065672139a..3135f6a9f240 100644 --- a/src/target/generic_func.cc +++ b/src/target/generic_func.cc @@ -31,6 +31,8 @@ #include #include +#include "../runtime/object_internal.h" + namespace tvm { TVM_REGISTER_NODE_TYPE(GenericFuncNode); @@ -143,26 +145,26 @@ TVM_REGISTER_GLOBAL("target.GenericFuncGetGlobal").set_body([](TVMArgs args, TVM TVM_REGISTER_GLOBAL("target.GenericFuncSetDefault").set_body([](TVMArgs args, TVMRetValue* ret) { GenericFunc generic_func = args[0]; - // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown - PackedFunc* func = new PackedFunc(args[1].operator PackedFunc()); + PackedFunc func = args[1]; bool allow_override = args[2]; - - generic_func.set_default(*func, allow_override); + // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown + runtime::ObjectInternal::ObjectRetain((TVMObjectHandle)(func.get())); + generic_func.set_default(func, allow_override); }); TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc").set_body([](TVMArgs args, TVMRetValue* ret) { GenericFunc generic_func = args[0]; - // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown - PackedFunc* func = new PackedFunc(args[1].operator PackedFunc()); + PackedFunc func = args[1]; Array tags = args[2]; bool allow_override = args[3]; - + // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown + runtime::ObjectInternal::ObjectRetain((TVMObjectHandle)(func.get())); std::vector tags_vector; for (auto& tag : tags) { tags_vector.push_back(tag); } - generic_func.register_func(tags_vector, *func, allow_override); + generic_func.register_func(tags_vector, func, allow_override); }); TVM_REGISTER_GLOBAL("target.GenericFuncCallFunc").set_body([](TVMArgs args, TVMRetValue* ret) { diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc index 3054bd0d7109..5a5c9361fc92 100644 --- a/web/emcc/tvmjs_support.cc +++ b/web/emcc/tvmjs_support.cc @@ -108,9 +108,19 @@ class AsyncLocalSession : public LocalSession { return get_time_eval_placeholder_.get(); } else if (auto* fp = tvm::runtime::Registry::Get(name)) { // return raw handle because the remote need to explicitly manage it. - return new PackedFunc(*fp); + tvm::runtime::TVMRetValue ret; + ret = *fp; + TVMValue val; + int type_code; + ret.MoveToCHost(&val, &type_code); + return val.v_handle; } else if (auto* fp = tvm::runtime::Registry::Get("__async." + name)) { - auto* rptr = new PackedFunc(*fp); + tvm::runtime::TVMRetValue ret; + ret = *fp; + TVMValue val; + int type_code; + ret.MoveToCHost(&val, &type_code); + auto* rptr = val.v_handle; async_func_set_.insert(rptr); return rptr; } else {