diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 6822159cf119..13a699ae036f 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -123,7 +123,7 @@ class PrimExpr : public BaseExpr { private: // Internal function for conversion. - friend class runtime::TVMPODValue_; + friend struct runtime::PackedFuncValueConverter; TVM_DLL static PrimExpr FromObject_(ObjectPtr ptr); }; @@ -451,22 +451,24 @@ inline const TTypeNode* RelayExprNode::type_as() const { namespace tvm { namespace runtime { -// Additional implementattion overloads for PackedFunc. -inline TVMPODValue_::operator tvm::PrimExpr() const { - if (type_code_ == kTVMNullptr) return PrimExpr(); - if (type_code_ == kDLInt) { - CHECK_LE(value_.v_int64, std::numeric_limits::max()); - CHECK_GE(value_.v_int64, std::numeric_limits::min()); - return PrimExpr(static_cast(value_.v_int64)); +template<> +struct PackedFuncValueConverter { + // common rule for both RetValue and ArgValue. + static PrimExpr From(const TVMPODValue_& val) { + if (val.type_code() == kTVMNullptr) { + return PrimExpr(ObjectPtr(nullptr)); + } + if (val.type_code() == kDLInt) { + return PrimExpr(val.operator int()); + } + if (val.type_code() == kDLFloat) { + return PrimExpr(static_cast(val.operator double())); + } + TVM_CHECK_TYPE_CODE(val.type_code(), kTVMObjectHandle); + Object* ptr = val.ptr(); + return PrimExpr::FromObject_(GetObjectPtr(ptr)); } - if (type_code_ == kDLFloat) { - return PrimExpr(static_cast(value_.v_float64)); - } - - TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); - Object* ptr = static_cast(value_.v_handle); - return PrimExpr::FromObject_(ObjectPtr(ptr)); -} +}; } // namespace runtime } // namespace tvm #endif // TVM_IR_EXPR_H_ diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index cf2ac260d674..ba1edf84383e 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -26,6 +26,7 @@ #include #include #include +#include #include #include diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 28c390a025b3..920ecfbf9b13 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -104,6 +104,7 @@ typedef enum { kTVMStr = 11U, kTVMBytes = 12U, kTVMNDArrayHandle = 13U, + kTVMObjectRValueRefArg = 14U, // Extension codes for other frameworks to integrate TVM PackedFunc. // To make sure each framework's id do not conflict, use first and // last sections to mark ranges. @@ -290,7 +291,7 @@ TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret, * * \return 0 when success, -1 when failure happens. */ -TVM_DLL int TVMCbArgToReturn(TVMValue* value, int code); +TVM_DLL int TVMCbArgToReturn(TVMValue* value, int* code); /*! * \brief C type of packed function. diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 083f87f89bc9..8963f0921276 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -590,6 +591,25 @@ inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, } } +template<> +struct PackedFuncValueConverter<::tvm::runtime::String> { + static String From(const TVMArgValue& val) { + if (val.IsObjectRef()) { + return val.AsObjectRef(); + } else { + return tvm::runtime::String(val.operator std::string()); + } + } + + static String From(const TVMRetValue& val) { + if (val.IsObjectRef()) { + return val.AsObjectRef(); + } else { + return tvm::runtime::String(val.operator std::string()); + } + } +}; + } // namespace runtime } // namespace tvm diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 717bf5e72fc4..8005cf61ff83 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -477,6 +477,17 @@ class ObjectPtr { data_->IncRef(); } } + /*! + * \brief Move an ObjectPtr from an RValueRef argument. + * \param ref The rvalue reference. + * \return the moved result. + */ + static ObjectPtr MoveFromRValueRefArg(Object** ref) { + ObjectPtr ptr; + ptr.data_ = *ref; + *ref = nullptr; + return ptr; + } // friend classes friend class Object; friend class ObjectRef; @@ -489,6 +500,7 @@ class ObjectPtr { friend class TVMArgsSetter; friend class TVMRetValue; friend class TVMArgValue; + friend class TVMMovableArgValue_; template friend RelayRefType GetRef(const ObjType* ptr); template @@ -550,6 +562,10 @@ class ObjectRef { bool unique() const { return data_.unique(); } + /*! \return The use count of the ptr, for debug purposes */ + int use_count() const { + return data_.use_count(); + } /*! * \brief Try to downcast the internal Object to a * raw pointer of a corresponding type. diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 1b3ad571dbaa..2dcb4ffcfa79 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -30,7 +30,6 @@ #include #include #include -#include #include #include #include @@ -47,15 +46,12 @@ #endif namespace tvm { -// forward declarations -class Integer; -class PrimExpr; - namespace runtime { // forward declarations class TVMArgs; class TVMArgValue; +class TVMMovableArgValue_; class TVMRetValue; class TVMArgsSetter; @@ -210,6 +206,11 @@ class TypedPackedFunc { * \param value The TVMArgValue */ inline TypedPackedFunc(const TVMArgValue& value); // NOLINT(*) + /*! + * \brief constructor from TVMMovableArgValue_ + * \param value The TVMMovableArgValue_ + */ + inline TypedPackedFunc(TVMMovableArgValue_&& value); // NOLINT(*) /*! * \brief construct from a lambda function with the same signature. * @@ -386,8 +387,8 @@ class TVMPODValue_ { } operator int() const { TVM_CHECK_TYPE_CODE(type_code_, kDLInt); - CHECK_LE(value_.v_int64, - std::numeric_limits::max()); + CHECK_LE(value_.v_int64, std::numeric_limits::max()); + CHECK_GE(value_.v_int64, std::numeric_limits::min()); return static_cast(value_.v_int64); } operator bool() const { @@ -449,9 +450,6 @@ class TVMPODValue_ { inline bool IsObjectRef() const; template inline TObjectRef AsObjectRef() const; - // ObjectRef Specializations - inline operator tvm::PrimExpr() const; - inline operator tvm::Integer() const; protected: friend class TVMArgsSetter; @@ -497,8 +495,6 @@ class TVMArgValue : public TVMPODValue_ { using TVMPODValue_::operator Module; using TVMPODValue_::IsObjectRef; using TVMPODValue_::AsObjectRef; - using TVMPODValue_::operator tvm::PrimExpr; - using TVMPODValue_::operator tvm::Integer; // conversion operator. operator std::string() const { @@ -512,13 +508,6 @@ class TVMArgValue : public TVMPODValue_ { return std::string(value_.v_str); } } - operator tvm::runtime::String() const { - if (IsObjectRef()) { - return AsObjectRef(); - } else { - return tvm::runtime::String(operator std::string()); - } - } operator DLDataType() const { if (type_code_ == kTVMStr) { return String2DLDataType(operator std::string()); @@ -547,12 +536,52 @@ class TVMArgValue : public TVMPODValue_ { const TVMValue& value() const { return value_; } + template::value>::type> inline operator T() const; }; +/*! + * \brief Internal auxiliary struct for TypedPackedFunc to indicate a movable argument. + * + * We can only construct a movable argument once from a single argument position. + * If the argument is passed as RValue reference, the result will be moved. + * We should only construct a MovableArg from an argument once, + * as the result will can moved. + * + * \note For internal development purpose only. + */ +class TVMMovableArgValue_ : public TVMArgValue { + public: + TVMMovableArgValue_(TVMValue value, int type_code) + : TVMArgValue(value, type_code) { + } + // reuse converter from parent + using TVMArgValue::operator double; + using TVMArgValue::operator int64_t; + using TVMArgValue::operator uint64_t; + using TVMArgValue::operator int; + using TVMArgValue::operator bool; + using TVMArgValue::operator void*; + using TVMArgValue::operator DLTensor*; + using TVMArgValue::operator TVMContext; + using TVMArgValue::operator std::string; + using TVMArgValue::operator DLDataType; + using TVMArgValue::operator DataType; + using TVMArgValue::operator PackedFunc; + /*! + * \brief Helper converter function. + * Try to move out an argument if possible, + * fall back to normal argument conversion rule otherwise. + */ + template::value>::type> + inline operator T() const; +}; + /*! * \brief Return Value container, * Unlike TVMArgValue, which only holds reference and do not delete @@ -591,8 +620,6 @@ class TVMRetValue : public TVMPODValue_ { using TVMPODValue_::operator Module; using TVMPODValue_::IsObjectRef; using TVMPODValue_::AsObjectRef; - using TVMPODValue_::operator tvm::PrimExpr; - using TVMPODValue_::operator tvm::Integer; TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); @@ -607,13 +634,6 @@ class TVMRetValue : public TVMPODValue_ { TVM_CHECK_TYPE_CODE(type_code_, kTVMStr); return *ptr(); } - operator tvm::runtime::String() const { - if (IsObjectRef()) { - return AsObjectRef(); - } else { - return tvm::runtime::String(operator std::string()); - } - } operator DLDataType() const { if (type_code_ == kTVMStr) { return String2DLDataType(operator std::string()); @@ -723,6 +743,10 @@ class TVMRetValue : public TVMPODValue_ { this->Assign(other); return *this; } + TVMRetValue& operator=(TVMMovableArgValue_&& other) { + this->Assign(other); + return *this; + } /*! * \brief Move the value back to front-end via C API. * This marks the current container as null. @@ -806,6 +830,10 @@ class TVMRetValue : public TVMPODValue_ { static_cast(other.value_.v_handle))); break; } + case kTVMObjectRValueRefArg: { + operator=(other.operator ObjectRef()); + break; + } default: { SwitchToPOD(other.type_code()); value_ = other.value_; @@ -863,6 +891,35 @@ class TVMRetValue : public TVMPODValue_ { } }; +/*! + * \brief Type trait to specify special value conversion rules from + * TVMArgValue and TVMRetValue. + * + * The trait can be specialized to add type specific conversion logic + * from the TVMArgvalue and TVMRetValue. + * + * \tparam TObjectRef the specific ObjectRefType. + */ +template +struct PackedFuncValueConverter { + /*! + * \brief Convert a TObjectRef from an argument value. + * \param val The argument value. + * \return the converted result. + */ + static TObjectRef From(const TVMArgValue& val) { + return val.AsObjectRef(); + } + /*! + * \brief Convert a TObjectRef from a return value. + * \param val The argument value. + * \return the converted result. + */ + static TObjectRef From(const TVMRetValue& val) { + return val.AsObjectRef(); + } +}; + /*! * \brief Export a function with the PackedFunc signature * as a PackedFunc that can be loaded by LibraryModule. @@ -1132,10 +1189,24 @@ class TVMArgsSetter { // ObjectRef handling template::value>::type> - inline void operator()(size_t i, const TObjectRef& value) const; + std::is_base_of::value> + ::type> + void operator()(size_t i, const TObjectRef& value) const { + this->SetObject(i, value); + } + + template::type>::value> + ::type> + void operator()(size_t i, TObjectRef&& value) const { + this->SetObject(i, std::forward(value)); + } private: + template + inline void SetObject(size_t i, TObjectRef&& value) const; /*! \brief The values fields */ TVMValue* values_; /*! \brief The type code fields */ @@ -1163,10 +1234,13 @@ struct unpack_call_dispatcher { const TVMArgs& args_pack, TVMRetValue* rv, Args&&... unpacked_args) { + // construct a movable argument value + // which allows potential move of argument to the input of F. unpack_call_dispatcher ::run(f, args_pack, rv, std::forward(unpacked_args)..., - args_pack[index]); + TVMMovableArgValue_(args_pack.values[index], + args_pack.type_codes[index])); } }; @@ -1245,6 +1319,10 @@ template TypedPackedFunc::TypedPackedFunc(const TVMArgValue& value) : packed_(value.operator PackedFunc()) {} +template +TypedPackedFunc::TypedPackedFunc(TVMMovableArgValue_&& value) + : packed_(value.operator PackedFunc()) {} + template template inline void TypedPackedFunc::AssignTypedLambda(FType flambda) { @@ -1264,8 +1342,9 @@ inline R TypedPackedFunc::operator()(Args... args) const { // kTVMNDArrayHandle, kTVMModuleHandle, kTVMObjectHandle // // We use type traits to eliminate un-necessary checks. -template -inline void TVMArgsSetter::operator()(size_t i, const TObjectRef& value) const { +template +inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { + using TObjectRef = typename std::remove_reference::type; if (value.defined()) { Object* ptr = value.data_.data_; if (std::is_base_of::value || @@ -1278,8 +1357,11 @@ inline void TVMArgsSetter::operator()(size_t i, const TObjectRef& value) const { ptr->IsInstance())) { values_[i].v_handle = ptr; type_codes_[i] = kTVMModuleHandle; + } else if (std::is_rvalue_reference::value) { + values_[i].v_handle = const_cast(&(value.data_.data_)); + type_codes_[i] = kTVMObjectRValueRefArg; } else { - values_[i].v_handle = ptr; + values_[i].v_handle = value.data_.data_; type_codes_[i] = kTVMObjectHandle; } } else { @@ -1300,6 +1382,11 @@ inline bool TVMPODValue_::IsObjectRef() const { return type_code_ == kTVMModuleHandle && static_cast(value_.v_handle)->IsInstance(); } + // NOTE: we don't pass NDArray and runtime::Module as RValue ref. + if (type_code_ == kTVMObjectRValueRefArg) { + return ObjectTypeChecker::Check( + *static_cast(value_.v_handle)); + } return (std::is_base_of::value && type_code_ == kTVMNDArrayHandle) || (std::is_base_of::value && type_code_ == kTVMModuleHandle) || @@ -1339,6 +1426,12 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expect " << ObjectTypeChecker::TypeName() << " but get " << ptr->GetTypeKey(); return TObjectRef(GetObjectPtr(ptr)); + } else if (type_code_ == kTVMObjectRValueRefArg) { + Object* ptr = *static_cast(value_.v_handle); + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expect " << ObjectTypeChecker::TypeName() + << " but get " << ptr->GetTypeKey(); + return TObjectRef(GetObjectPtr(ptr)); } else if (std::is_base_of::value && type_code_ == kTVMNDArrayHandle) { // Casting to a base class that NDArray can sub-class @@ -1376,14 +1469,27 @@ inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { return *this; } + template inline TVMArgValue::operator T() const { - return AsObjectRef(); + return PackedFuncValueConverter::From(*this); +} + +template +inline TVMMovableArgValue_::operator T() const { + if (type_code_ == kTVMObjectRValueRefArg) { + auto** ref = static_cast(value_.v_handle); + if (ObjectTypeChecker::Check(*ref)) { + return T(ObjectPtr::MoveFromRValueRefArg(ref)); + } + } + // fallback + return PackedFuncValueConverter::From(*this); } template inline TVMRetValue::operator T() const { - return AsObjectRef(); + return PackedFuncValueConverter::From(*this); } inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) { diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 6295a366c6ad..a1603d5e7bda 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -1338,20 +1338,20 @@ enum TVMStructFieldKind : int { namespace tvm { namespace runtime { // Additional implementattion overloads for PackedFunc. -inline TVMPODValue_::operator tvm::Integer() const { - if (type_code_ == kTVMNullptr) return Integer(); - if (type_code_ == kDLInt) { - CHECK_LE(value_.v_int64, std::numeric_limits::max()); - CHECK_GE(value_.v_int64, std::numeric_limits::min()); - return Integer(static_cast(value_.v_int64)); + +template<> +struct PackedFuncValueConverter { + // common rule for RetValue and ArgValue + static tvm::Integer From(const TVMPODValue_& val) { + if (val.type_code() == kTVMNullptr) { + return Integer(ObjectPtr(nullptr)); + } + if (val.type_code() == kDLInt) { + return Integer(val.operator int()); + } + return val.AsObjectRef(); } - TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); - Object* ptr = static_cast(value_.v_handle); - CHECK(ObjectTypeChecker::Check(ptr)) - << "Expect type " << ObjectTypeChecker::TypeName() - << " but get " << ptr->GetTypeKey(); - return Integer(ObjectPtr(ptr)); -} +}; } // namespace runtime } // namespace tvm diff --git a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc index e3446c93cc6f..b59956824d26 100644 --- a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc +++ b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc @@ -244,8 +244,9 @@ extern "C" int funcInvokeCallback(TVMValue *args, int tcode = typeCodes[i]; if (tcode == kTVMObjectHandle || tcode == kTVMPackedFuncHandle || + tcode == kTVMObjectRValueRefArg || tcode == kTVMModuleHandle) { - TVMCbArgToReturn(&arg, tcode); + TVMCbArgToReturn(&arg, &tcode); } jobject jarg = tvmRetValueToJava(env, arg, tcode); env->SetObjectArrayElement(jargs, i, jarg); diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py index 263a76d414a8..b5dc65fd5e79 100644 --- a/python/tvm/_ffi/_ctypes/object.py +++ b/python/tvm/_ffi/_ctypes/object.py @@ -60,6 +60,9 @@ def _return_object(x): C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func( _return_object, TypeCode.OBJECT_HANDLE) +C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_RVALUE_REF_ARG] = _wrap_arg_func( + _return_object, TypeCode.OBJECT_RVALUE_REF_ARG) + class ObjectBase(object): """Base object for all object types""" diff --git a/python/tvm/_ffi/_ctypes/packed_func.py b/python/tvm/_ffi/_ctypes/packed_func.py index 6e1dbf55f618..11bb65504c61 100644 --- a/python/tvm/_ffi/_ctypes/packed_func.py +++ b/python/tvm/_ffi/_ctypes/packed_func.py @@ -23,7 +23,7 @@ from ..base import _LIB, get_last_ffi_error, py2cerror, check_call from ..base import c_str, string_types -from ..runtime_ctypes import DataType, TVMByteArray, TVMContext +from ..runtime_ctypes import DataType, TVMByteArray, TVMContext, ObjectRValueRef from . import ndarray as _nd from .ndarray import NDArrayBase, _make_array from .types import TVMValue, TypeCode @@ -164,6 +164,9 @@ def _make_tvm_args(args, temp_args): elif isinstance(arg, ctypes.c_void_p): values[i].v_handle = arg type_codes[i] = TypeCode.HANDLE + elif isinstance(arg, ObjectRValueRef): + values[i].v_handle = ctypes.cast(ctypes.byref(arg.obj.handle), ctypes.c_void_p) + type_codes[i] = TypeCode.OBJECT_RVALUE_REF_ARG elif callable(arg): arg = convert_to_tvm_func(arg) values[i].v_handle = arg.handle diff --git a/python/tvm/_ffi/_ctypes/types.py b/python/tvm/_ffi/_ctypes/types.py index f45748fdd4de..20be30a59b2f 100644 --- a/python/tvm/_ffi/_ctypes/types.py +++ b/python/tvm/_ffi/_ctypes/types.py @@ -73,9 +73,9 @@ def _return_context(value): def _wrap_arg_func(return_f, type_code): - tcode = ctypes.c_int(type_code) def _wrap_func(x): - check_call(_LIB.TVMCbArgToReturn(ctypes.byref(x), tcode)) + tcode = ctypes.c_int(type_code) + check_call(_LIB.TVMCbArgToReturn(ctypes.byref(x), ctypes.byref(tcode))) return return_f(x) return _wrap_func diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index ad281d7512d9..0da66ac2e034 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -37,6 +37,7 @@ cdef enum TVMTypeCode: kTVMStr = 11 kTVMBytes = 12 kTVMNDArrayHandle = 13 + kTVMObjectRefArg = 14 kTVMExtBegin = 15 cdef extern from "tvm/runtime/c_runtime_api.h": @@ -113,7 +114,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h": void* resource_handle, TVMPackedCFuncFinalizer fin, TVMPackedFuncHandle *out) - int TVMCbArgToReturn(TVMValue* value, int code) + int TVMCbArgToReturn(TVMValue* value, int* code) int TVMArrayAlloc(tvm_index_t* shape, tvm_index_t ndim, DLDataType dtype, diff --git a/python/tvm/_ffi/_cython/object.pxi b/python/tvm/_ffi/_cython/object.pxi index 2a345cad684e..f2b5cc172d45 100644 --- a/python/tvm/_ffi/_cython/object.pxi +++ b/python/tvm/_ffi/_cython/object.pxi @@ -64,10 +64,7 @@ cdef class ObjectBase: property handle: def __get__(self): - if self.chandle == NULL: - return None - else: - return ctypes_handle(self.chandle) + return ctypes_handle(self.chandle) def __set__(self, value): self._set_handle(value) diff --git a/python/tvm/_ffi/_cython/packed_func.pxi b/python/tvm/_ffi/_cython/packed_func.pxi index 9d13dbd290d6..6977e108bf88 100644 --- a/python/tvm/_ffi/_cython/packed_func.pxi +++ b/python/tvm/_ffi/_cython/packed_func.pxi @@ -20,7 +20,7 @@ import traceback from cpython cimport Py_INCREF, Py_DECREF from numbers import Number, Integral from ..base import string_types, py2cerror -from ..runtime_ctypes import DataType, TVMContext, TVMByteArray +from ..runtime_ctypes import DataType, TVMContext, TVMByteArray, ObjectRValueRef cdef void tvm_callback_finalize(void* fhandle): @@ -43,8 +43,9 @@ cdef int tvm_callback(TVMValue* args, if (tcode == kTVMObjectHandle or tcode == kTVMPackedFuncHandle or tcode == kTVMModuleHandle or + tcode == kTVMObjectRefArg or tcode > kTVMExtBegin): - CALL(TVMCbArgToReturn(&value, tcode)) + CALL(TVMCbArgToReturn(&value, &tcode)) if tcode != kTVMDLTensorHandle: pyargs.append(make_ret(value, tcode)) @@ -167,6 +168,9 @@ cdef inline int make_arg(object arg, elif isinstance(arg, ctypes.c_void_p): value[0].v_handle = c_handle(arg) tcode[0] = kTVMOpaqueHandle + elif isinstance(arg, ObjectRValueRef): + value[0].v_handle = &(((arg.obj)).chandle) + tcode[0] = kTVMObjectRefArg elif callable(arg): arg = convert_to_tvm_func(arg) value[0].v_handle = (arg).chandle diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 160cc3ec9e21..6b06ad01c9ff 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -39,6 +39,7 @@ class TypeCode(object): STR = 11 BYTES = 12 NDARRAY_HANDLE = 13 + OBJECT_RVALUE_REF_ARG = 14 EXT_BEGIN = 15 @@ -281,4 +282,18 @@ class TVMArray(ctypes.Structure): ("strides", ctypes.POINTER(tvm_shape_index_t)), ("byte_offset", ctypes.c_uint64)] + +class ObjectRValueRef: + """Represent an RValue ref to an object that can be moved. + + Parameters + ---------- + obj : tvm.runtime.Object + The object that this value refers to + """ + __slots__ = ["obj"] + def __init__(self, obj): + self.obj = obj + + TVMArrayHandle = ctypes.POINTER(TVMArray) diff --git a/python/tvm/runtime/object.py b/python/tvm/runtime/object.py index f725c1908ba5..a55eeb0cb3ee 100644 --- a/python/tvm/runtime/object.py +++ b/python/tvm/runtime/object.py @@ -19,6 +19,7 @@ import ctypes from tvm._ffi.base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str +from tvm._ffi.runtime_ctypes import ObjectRValueRef from . import _ffi_api, _ffi_node_api try: @@ -85,5 +86,35 @@ def __setstate__(self, state): else: self.handle = None + def _move(self): + """Create an RValue reference to the object and mark the object as moved. + + This is a advanced developer API that can be useful when passing an + unique reference to an Object that you no longer needed to a function. + + A unique reference can trigger copy on write optimization that avoids + copy when we transform an object. + + Note + ---- + All the reference of the object becomes invalid after it is moved. + Be very careful when using this feature. + + Examples + -------- + + .. code-block:: python + + x = tvm.tir.Var("x", "int32") + x0 = x + some_packed_func(x._move()) + # both x0 and x will points to None after the function call. + + Returns + ------- + rvalue : The rvalue reference. + """ + return ObjectRValueRef(self) + _set_class_object(Object) diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index a7716df83189..ac20b67e8299 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -18,6 +18,7 @@ # pylint: disable=unused-import, invalid-name from numbers import Number, Integral from tvm._ffi.base import string_types +from tvm._ffi.runtime_ctypes import ObjectRValueRef from . import _ffi_node_api, _ffi_api from .object import ObjectBase, _set_class_object_generic @@ -33,7 +34,7 @@ def asobject(self): raise NotImplementedError() -ObjectTypes = (ObjectBase, NDArrayBase, Module) +ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef) def convert_to_object(value): diff --git a/rust/frontend/src/function.rs b/rust/frontend/src/function.rs index 7f055259fe0c..d9c0e5c1aa8b 100644 --- a/rust/frontend/src/function.rs +++ b/rust/frontend/src/function.rs @@ -261,7 +261,7 @@ unsafe extern "C" fn tvm_callback( || tcode == ffi::TVMTypeCode_kTVMPackedFuncHandle as c_int || tcode == ffi::TVMTypeCode_kTVMModuleHandle as c_int { - check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, tcode)); + check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, &mut tcode as *mut _)); } local_args.push(TVMArgValue::from_tvm_value(value, tcode as u32)); } diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 6e38aac92ec0..49c0ef414d4f 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -371,7 +371,7 @@ TVM_REGISTER_GLOBAL("transform.MakeModulePass") .set_body_typed( [](runtime::TypedPackedFunc pass_func, PassInfo pass_info) { - return ModulePass(pass_func, pass_info); + return ModulePass(pass_func, pass_info); }); TVM_REGISTER_GLOBAL("transform.RunPass") diff --git a/src/node/container.cc b/src/node/container.cc index bce2eeea0ec9..0949da0eade2 100644 --- a/src/node/container.cc +++ b/src/node/container.cc @@ -370,7 +370,6 @@ TVM_REGISTER_GLOBAL("node.MapGetItem") Object* ptr = static_cast(args[0].value().v_handle); if (ptr->IsInstance()) { - CHECK(args[1].type_code() == kTVMObjectHandle); auto* n = static_cast(ptr); auto it = n->data.find(args[1].operator ObjectRef()); CHECK(it != n->data.end()) diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 6a50bae2dbe9..fb1f74da2103 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -577,13 +577,11 @@ int TVMStreamStreamSynchronize(int device_type, API_END(); } -int TVMCbArgToReturn(TVMValue* value, int code) { +int TVMCbArgToReturn(TVMValue* value, int* code) { API_BEGIN(); tvm::runtime::TVMRetValue rv; - rv = tvm::runtime::TVMArgValue(*value, code); - int tcode; - rv.MoveToCHost(value, &tcode); - CHECK_EQ(tcode, code); + rv = tvm::runtime::TVMMovableArgValue_(*value, *code); + rv.MoveToCHost(value, code); API_END(); } diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 9053f6298999..90fcfff0eef3 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -107,11 +107,11 @@ TVM_REGISTER_GLOBAL("testing.ErrorTest") .set_body_typed(ErrorTest); // internal function used for debug and testing purposes -TVM_REGISTER_GLOBAL("testing.ndarray_use_count") +TVM_REGISTER_GLOBAL("testing.object_use_count") .set_body([](TVMArgs args, TVMRetValue *ret) { - runtime::NDArray nd = args[0]; - // substract the current one - *ret = (nd.use_count() - 1); + runtime::ObjectRef obj = args[0]; + // substract the current one because we always copy + // and get another value. + *ret = (obj.use_count() - 1); }); - } // namespace tvm diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index d0313c60d984..99b6ca25a162 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -51,7 +52,7 @@ TEST(PackedFunc, Node) { Var x; Var t = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { CHECK(args.num_args == 1); - CHECK(args.type_codes[0] == kTVMObjectHandle); + CHECK(args[0].IsObjectRef()); Var b = args[0]; CHECK(x.same_as(b)); *rv = b; @@ -269,6 +270,50 @@ TEST(PackedFunc, ObjectConversion) { pf2(ObjectRef(m), Module()); } +TEST(TypedPackedFunc, RValue) { + using namespace tvm; + using namespace tvm::runtime; + { + auto f = [](tir::Var x, bool move) { + if (move) { + CHECK(x.unique()); + } else { + CHECK(!x.unique()); + } + CHECK(x->name_hint == "x"); + return x; + }; + TypedPackedFunc tf(f); + + tir::Var var("x"); + CHECK(var.unique()); + f(var, false); + // move the result to the function. + tir::Var ret = f(std::move(var), true); + CHECK(!var.defined()); + } + + { + // pass child class. + auto f = [](PrimExpr x, bool move) { + if (move) { + CHECK(x.unique()); + } else { + CHECK(!x.unique()); + } + return x; + }; + TypedPackedFunc tf(f); + + tir::Var var("x"); + CHECK(var.unique()); + f(var, false); + f(std::move(var), true); + // auto conversion. + f(1, true); + } +} + int main(int argc, char ** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; diff --git a/tests/python/unittest/test_runtime_packed_func.py b/tests/python/unittest/test_runtime_packed_func.py index 3570fe149608..fcaceb3dcaeb 100644 --- a/tests/python/unittest/test_runtime_packed_func.py +++ b/tests/python/unittest/test_runtime_packed_func.py @@ -98,6 +98,33 @@ def test_ctx_func(ctx): x = tvm.testing.context_test(x, x.device_type, x.device_id) assert x == tvm.opencl(10) + +def test_rvalue_ref(): + def callback(x, expected_count): + assert expected_count == tvm.testing.object_use_count(x) + return x + + f = tvm.runtime.convert(callback) + + def check0(): + x = tvm.tir.Var("x", "int32") + assert tvm.testing.object_use_count(x) == 1 + f(x, 2) + y = f(x._move(), 1) + assert x.handle.value == None + + def check1(): + x = tvm.tir.Var("x", "int32") + assert tvm.testing.object_use_count(x) == 1 + y = f(x, 2) + z = f(x._move(), 2) + assert x.handle.value == None + assert y.handle.value is not None + + check0() + check1() + + def test_trace_default_action(): n = 2 x = te.placeholder((n,n,n), name="X", dtype="float32") @@ -269,7 +296,11 @@ def check_assign(dtype): for t in ["float64", "float32"]: check_assign(t) + + if __name__ == "__main__": + test_rvalue_ref() + exit(0) test_empty_array() test_get_global() test_get_callback_with_node() diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index 1d9b79eca875..b61e6bb9fa01 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -212,7 +212,7 @@ def my_module(name): if name == "get_arr": return lambda : nd elif name == "ref_count": - return lambda : tvm.testing.ndarray_use_count(nd) + return lambda : tvm.testing.object_use_count(nd) elif name == "get_elem": return lambda idx: nd.asnumpy()[idx] elif name == "get_arr_elem": diff --git a/web/tvm_runtime.js b/web/tvm_runtime.js index 0740efc5ff6a..b62b298d969e 100644 --- a/web/tvm_runtime.js +++ b/web/tvm_runtime.js @@ -105,6 +105,7 @@ var tvm_runtime = tvm_runtime || {}; var kTVMPackedFuncHandle = 10; var kTVMStr = 11; var kTVMBytes = 12; + var kTVMObjectRValueRefArg = 14; //----------------------------------------- // TVM CWrap library // ---------------------------------------- @@ -171,7 +172,7 @@ var tvm_runtime = tvm_runtime || {}; ("TVMCbArgToReturn", "number", ["number", // TVMValue* value - "number" // int code + "number" // int* code ]); var TVMFuncCreateFromCFunc = Module.cwrap @@ -496,12 +497,15 @@ var tvm_runtime = tvm_runtime || {}; var args = []; for (var i = 0; i < nargs; ++i) { var vptr = arg_value + i * SIZEOF_TVMVALUE; - var tcode = Module.getValue(arg_tcode + i * SIZEOF_INT, "i32"); + var tcodeptr = arg_tcode + i * SIZEOF_INT; + var tcode = Module.getValue(tcodeptr, "i32"); if (tcode == kTVMObjectHandle || + tcode == kTVMObjectRValueRefArg || tcode == kTVMPackedFuncHandle || tcode == kTVMModuleHandle) { - TVM_CALL(TVMCbArgToReturn(vptr, tcode)); + TVM_CALL(TVMCbArgToReturn(vptr, tcodeptr)); } + tcode = Module.getValue(tcodeptr, "i32"); args.push(TVMRetValueToJS(vptr, tcode)); } var rv = funcTable[handle].apply(null, args);