From 8128fdd457589dd86b5f90c01da49e002bd4b8fc Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 7 Apr 2020 11:23:48 -0700 Subject: [PATCH 1/3] [RUNTIME] Introduce RValue reference(move) support to TypedPackedFunc This PR introduces RValue reference support the PackedFunc calling convention to address the above issue. Specifically, when an argument is a r-value reference, we will use a assign a different type code(`kObjectRValueRefArg`), and pass `Object**` (the address to the Object pointer) instead through the values array. The callee can choose to move out this Object pointer and set the original Object pointer from the caller side to be nullptr. We also add an experimental move support to the python side(marked as _move so to indicate the dev nature). This enhancement will enable copy on write optimizations through out the TVM stack. --- include/tvm/ir/expr.h | 34 ++-- include/tvm/runtime/c_runtime_api.h | 3 +- include/tvm/runtime/container.h | 20 ++ include/tvm/runtime/object.h | 16 ++ include/tvm/runtime/packed_func.h | 178 ++++++++++++++---- include/tvm/tir/expr.h | 26 +-- .../native/org_apache_tvm_native_c_api.cc | 3 +- python/tvm/_ffi/_ctypes/object.py | 3 + python/tvm/_ffi/_ctypes/packed_func.py | 5 +- python/tvm/_ffi/_ctypes/types.py | 4 +- python/tvm/_ffi/_cython/base.pxi | 3 +- python/tvm/_ffi/_cython/object.pxi | 5 +- python/tvm/_ffi/_cython/packed_func.pxi | 8 +- python/tvm/_ffi/runtime_ctypes.py | 15 ++ python/tvm/runtime/object.py | 31 +++ python/tvm/runtime/object_generic.py | 3 +- rust/frontend/src/function.rs | 2 +- src/ir/transform.cc | 2 +- src/node/container.cc | 1 - src/runtime/c_runtime_api.cc | 8 +- src/support/ffi_testing.cc | 10 +- tests/cpp/packed_func_test.cc | 47 ++++- .../unittest/test_runtime_packed_func.py | 31 +++ tests/python/unittest/test_runtime_rpc.py | 2 +- web/tvm_runtime.js | 10 +- 25 files changed, 374 insertions(+), 96 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 6822159cf119..13a699ae036f 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -123,7 +123,7 @@ class PrimExpr : public BaseExpr { private: // Internal function for conversion. - friend class runtime::TVMPODValue_; + friend struct runtime::PackedFuncValueConverter; TVM_DLL static PrimExpr FromObject_(ObjectPtr ptr); }; @@ -451,22 +451,24 @@ inline const TTypeNode* RelayExprNode::type_as() const { namespace tvm { namespace runtime { -// Additional implementattion overloads for PackedFunc. -inline TVMPODValue_::operator tvm::PrimExpr() const { - if (type_code_ == kTVMNullptr) return PrimExpr(); - if (type_code_ == kDLInt) { - CHECK_LE(value_.v_int64, std::numeric_limits::max()); - CHECK_GE(value_.v_int64, std::numeric_limits::min()); - return PrimExpr(static_cast(value_.v_int64)); +template<> +struct PackedFuncValueConverter { + // common rule for both RetValue and ArgValue. + static PrimExpr From(const TVMPODValue_& val) { + if (val.type_code() == kTVMNullptr) { + return PrimExpr(ObjectPtr(nullptr)); + } + if (val.type_code() == kDLInt) { + return PrimExpr(val.operator int()); + } + if (val.type_code() == kDLFloat) { + return PrimExpr(static_cast(val.operator double())); + } + TVM_CHECK_TYPE_CODE(val.type_code(), kTVMObjectHandle); + Object* ptr = val.ptr(); + return PrimExpr::FromObject_(GetObjectPtr(ptr)); } - if (type_code_ == kDLFloat) { - return PrimExpr(static_cast(value_.v_float64)); - } - - TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); - Object* ptr = static_cast(value_.v_handle); - return PrimExpr::FromObject_(ObjectPtr(ptr)); -} +}; } // namespace runtime } // namespace tvm #endif // TVM_IR_EXPR_H_ diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 28c390a025b3..920ecfbf9b13 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -104,6 +104,7 @@ typedef enum { kTVMStr = 11U, kTVMBytes = 12U, kTVMNDArrayHandle = 13U, + kTVMObjectRValueRefArg = 14U, // Extension codes for other frameworks to integrate TVM PackedFunc. // To make sure each framework's id do not conflict, use first and // last sections to mark ranges. @@ -290,7 +291,7 @@ TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret, * * \return 0 when success, -1 when failure happens. */ -TVM_DLL int TVMCbArgToReturn(TVMValue* value, int code); +TVM_DLL int TVMCbArgToReturn(TVMValue* value, int* code); /*! * \brief C type of packed function. diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 083f87f89bc9..8963f0921276 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -590,6 +591,25 @@ inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, } } +template<> +struct PackedFuncValueConverter<::tvm::runtime::String> { + static String From(const TVMArgValue& val) { + if (val.IsObjectRef()) { + return val.AsObjectRef(); + } else { + return tvm::runtime::String(val.operator std::string()); + } + } + + static String From(const TVMRetValue& val) { + if (val.IsObjectRef()) { + return val.AsObjectRef(); + } else { + return tvm::runtime::String(val.operator std::string()); + } + } +}; + } // namespace runtime } // namespace tvm diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 717bf5e72fc4..8005cf61ff83 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -477,6 +477,17 @@ class ObjectPtr { data_->IncRef(); } } + /*! + * \brief Move an ObjectPtr from an RValueRef argument. + * \param ref The rvalue reference. + * \return the moved result. + */ + static ObjectPtr MoveFromRValueRefArg(Object** ref) { + ObjectPtr ptr; + ptr.data_ = *ref; + *ref = nullptr; + return ptr; + } // friend classes friend class Object; friend class ObjectRef; @@ -489,6 +500,7 @@ class ObjectPtr { friend class TVMArgsSetter; friend class TVMRetValue; friend class TVMArgValue; + friend class TVMMovableArgValue_; template friend RelayRefType GetRef(const ObjType* ptr); template @@ -550,6 +562,10 @@ class ObjectRef { bool unique() const { return data_.unique(); } + /*! \return The use count of the ptr, for debug purposes */ + int use_count() const { + return data_.use_count(); + } /*! * \brief Try to downcast the internal Object to a * raw pointer of a corresponding type. diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 1b3ad571dbaa..6be461adabc4 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -30,7 +30,6 @@ #include #include #include -#include #include #include #include @@ -47,15 +46,12 @@ #endif namespace tvm { -// forward declarations -class Integer; -class PrimExpr; - namespace runtime { // forward declarations class TVMArgs; class TVMArgValue; +class TVMMovableArgValue_; class TVMRetValue; class TVMArgsSetter; @@ -210,6 +206,11 @@ class TypedPackedFunc { * \param value The TVMArgValue */ inline TypedPackedFunc(const TVMArgValue& value); // NOLINT(*) + /*! + * \brief constructor from TVMMovableArgValue_ + * \param value The TVMMovableArgValue_ + */ + inline TypedPackedFunc(TVMMovableArgValue_&& value); // NOLINT(*) /*! * \brief construct from a lambda function with the same signature. * @@ -386,8 +387,8 @@ class TVMPODValue_ { } operator int() const { TVM_CHECK_TYPE_CODE(type_code_, kDLInt); - CHECK_LE(value_.v_int64, - std::numeric_limits::max()); + CHECK_LE(value_.v_int64, std::numeric_limits::max()); + CHECK_GE(value_.v_int64, std::numeric_limits::min()); return static_cast(value_.v_int64); } operator bool() const { @@ -449,9 +450,6 @@ class TVMPODValue_ { inline bool IsObjectRef() const; template inline TObjectRef AsObjectRef() const; - // ObjectRef Specializations - inline operator tvm::PrimExpr() const; - inline operator tvm::Integer() const; protected: friend class TVMArgsSetter; @@ -497,8 +495,6 @@ class TVMArgValue : public TVMPODValue_ { using TVMPODValue_::operator Module; using TVMPODValue_::IsObjectRef; using TVMPODValue_::AsObjectRef; - using TVMPODValue_::operator tvm::PrimExpr; - using TVMPODValue_::operator tvm::Integer; // conversion operator. operator std::string() const { @@ -512,13 +508,6 @@ class TVMArgValue : public TVMPODValue_ { return std::string(value_.v_str); } } - operator tvm::runtime::String() const { - if (IsObjectRef()) { - return AsObjectRef(); - } else { - return tvm::runtime::String(operator std::string()); - } - } operator DLDataType() const { if (type_code_ == kTVMStr) { return String2DLDataType(operator std::string()); @@ -547,12 +536,52 @@ class TVMArgValue : public TVMPODValue_ { const TVMValue& value() const { return value_; } + template::value>::type> inline operator T() const; }; +/*! + * \brief Internal auxiliary struct for TypedPackedFunc to indicate a movable argument. + * + * We can only construct a movable argument once from a single argument position. + * If the argument is passed as RValue reference, the result will be moved. + * We should only construct a MovableArg from a argument once, + * as the result will can moved. + * + * \note For internal development purpose only. + */ +class TVMMovableArgValue_ : public TVMArgValue { + public: + TVMMovableArgValue_(TVMValue value, int type_code) + : TVMArgValue(value, type_code) { + } + // reuse converter from parent + using TVMArgValue::operator double; + using TVMArgValue::operator int64_t; + using TVMArgValue::operator uint64_t; + using TVMArgValue::operator int; + using TVMArgValue::operator bool; + using TVMArgValue::operator void*; + using TVMArgValue::operator DLTensor*; + using TVMArgValue::operator TVMContext; + using TVMArgValue::operator std::string; + using TVMArgValue::operator DLDataType; + using TVMArgValue::operator DataType; + using TVMArgValue::operator PackedFunc; + /*! + * \brief Helper converter function. + * Try to move out an argument if possible, + * fall back to normal argument conversion rule otherwise. + */ + template::value>::type> + inline operator T() const; +}; + /*! * \brief Return Value container, * Unlike TVMArgValue, which only holds reference and do not delete @@ -591,8 +620,6 @@ class TVMRetValue : public TVMPODValue_ { using TVMPODValue_::operator Module; using TVMPODValue_::IsObjectRef; using TVMPODValue_::AsObjectRef; - using TVMPODValue_::operator tvm::PrimExpr; - using TVMPODValue_::operator tvm::Integer; TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); @@ -607,13 +634,6 @@ class TVMRetValue : public TVMPODValue_ { TVM_CHECK_TYPE_CODE(type_code_, kTVMStr); return *ptr(); } - operator tvm::runtime::String() const { - if (IsObjectRef()) { - return AsObjectRef(); - } else { - return tvm::runtime::String(operator std::string()); - } - } operator DLDataType() const { if (type_code_ == kTVMStr) { return String2DLDataType(operator std::string()); @@ -723,6 +743,10 @@ class TVMRetValue : public TVMPODValue_ { this->Assign(other); return *this; } + TVMRetValue& operator=(TVMMovableArgValue_&& other) { + this->Assign(other); + return *this; + } /*! * \brief Move the value back to front-end via C API. * This marks the current container as null. @@ -806,6 +830,10 @@ class TVMRetValue : public TVMPODValue_ { static_cast(other.value_.v_handle))); break; } + case kTVMObjectRValueRefArg: { + operator=(other.operator ObjectRef()); + break; + } default: { SwitchToPOD(other.type_code()); value_ = other.value_; @@ -863,6 +891,35 @@ class TVMRetValue : public TVMPODValue_ { } }; +/*! + * \brief Type trait to specify special value conversion rules from + * TVMArgValue and TVMRetValue. + * + * The trait can be specialized to add type specific conversion logic + * from the TVMArgvalue and TVMRetValue. + * + * \tparam TObjectRef the specific ObjectRefType. + */ +template +struct PackedFuncValueConverter { + /*! + * \brief Convert an TObjectRef from an argument value. + * \param val The argument value. + * \return the converted result. + */ + static TObjectRef From(const TVMArgValue& val) { + return val.AsObjectRef(); + } + /*! + * \brief Convert an TObjectRef from an argument value. + * \param val The argument value. + * \return the converted result. + */ + static TObjectRef From(const TVMRetValue& val) { + return val.AsObjectRef(); + } +}; + /*! * \brief Export a function with the PackedFunc signature * as a PackedFunc that can be loaded by LibraryModule. @@ -1132,10 +1189,24 @@ class TVMArgsSetter { // ObjectRef handling template::value>::type> - inline void operator()(size_t i, const TObjectRef& value) const; + std::is_base_of::value> + ::type> + void operator()(size_t i, const TObjectRef& value) const { + this->SetObject(i, value); + } + + template::type>::value> + ::type> + void operator()(size_t i, TObjectRef&& value) const { + this->SetObject(i, std::forward(value)); + } private: + template + inline void SetObject(size_t i, TObjectRef&& value) const; /*! \brief The values fields */ TVMValue* values_; /*! \brief The type code fields */ @@ -1163,10 +1234,13 @@ struct unpack_call_dispatcher { const TVMArgs& args_pack, TVMRetValue* rv, Args&&... unpacked_args) { + // construct a movable argument value + // which allows potential move of argument to the input of F. unpack_call_dispatcher ::run(f, args_pack, rv, std::forward(unpacked_args)..., - args_pack[index]); + TVMMovableArgValue_(args_pack.values[index], + args_pack.type_codes[index])); } }; @@ -1245,6 +1319,10 @@ template TypedPackedFunc::TypedPackedFunc(const TVMArgValue& value) : packed_(value.operator PackedFunc()) {} +template +TypedPackedFunc::TypedPackedFunc(TVMMovableArgValue_&& value) + : packed_(value.operator PackedFunc()) {} + template template inline void TypedPackedFunc::AssignTypedLambda(FType flambda) { @@ -1264,8 +1342,9 @@ inline R TypedPackedFunc::operator()(Args... args) const { // kTVMNDArrayHandle, kTVMModuleHandle, kTVMObjectHandle // // We use type traits to eliminate un-necessary checks. -template -inline void TVMArgsSetter::operator()(size_t i, const TObjectRef& value) const { +template +inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { + using TObjectRef = typename std::remove_reference::type; if (value.defined()) { Object* ptr = value.data_.data_; if (std::is_base_of::value || @@ -1278,8 +1357,11 @@ inline void TVMArgsSetter::operator()(size_t i, const TObjectRef& value) const { ptr->IsInstance())) { values_[i].v_handle = ptr; type_codes_[i] = kTVMModuleHandle; + } else if (std::is_rvalue_reference::value) { + values_[i].v_handle = const_cast(&(value.data_.data_)); + type_codes_[i] = kTVMObjectRValueRefArg; } else { - values_[i].v_handle = ptr; + values_[i].v_handle = value.data_.data_; type_codes_[i] = kTVMObjectHandle; } } else { @@ -1300,6 +1382,11 @@ inline bool TVMPODValue_::IsObjectRef() const { return type_code_ == kTVMModuleHandle && static_cast(value_.v_handle)->IsInstance(); } + // NOTE: we don't pass NDArray and runtime::Module as RValue ref. + if (type_code_ == kTVMObjectRValueRefArg) { + return ObjectTypeChecker::Check( + *static_cast(value_.v_handle)); + } return (std::is_base_of::value && type_code_ == kTVMNDArrayHandle) || (std::is_base_of::value && type_code_ == kTVMModuleHandle) || @@ -1339,6 +1426,12 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expect " << ObjectTypeChecker::TypeName() << " but get " << ptr->GetTypeKey(); return TObjectRef(GetObjectPtr(ptr)); + } else if (type_code_ == kTVMObjectRValueRefArg) { + Object* ptr = *static_cast(value_.v_handle); + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expect " << ObjectTypeChecker::TypeName() + << " but get " << ptr->GetTypeKey(); + return TObjectRef(GetObjectPtr(ptr)); } else if (std::is_base_of::value && type_code_ == kTVMNDArrayHandle) { // Casting to a base class that NDArray can sub-class @@ -1376,14 +1469,27 @@ inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { return *this; } + template inline TVMArgValue::operator T() const { - return AsObjectRef(); + return PackedFuncValueConverter::From(*this); +} + +template +inline TVMMovableArgValue_::operator T() const { + if (type_code_ == kTVMObjectRValueRefArg) { + auto** ref = static_cast(value_.v_handle); + if (ObjectTypeChecker::Check(*ref)) { + return T(ObjectPtr::MoveFromRValueRefArg(ref)); + } + } + // fallback + return PackedFuncValueConverter::From(*this); } template inline TVMRetValue::operator T() const { - return AsObjectRef(); + return PackedFuncValueConverter::From(*this); } inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) { diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 6295a366c6ad..a1603d5e7bda 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -1338,20 +1338,20 @@ enum TVMStructFieldKind : int { namespace tvm { namespace runtime { // Additional implementattion overloads for PackedFunc. -inline TVMPODValue_::operator tvm::Integer() const { - if (type_code_ == kTVMNullptr) return Integer(); - if (type_code_ == kDLInt) { - CHECK_LE(value_.v_int64, std::numeric_limits::max()); - CHECK_GE(value_.v_int64, std::numeric_limits::min()); - return Integer(static_cast(value_.v_int64)); + +template<> +struct PackedFuncValueConverter { + // common rule for RetValue and ArgValue + static tvm::Integer From(const TVMPODValue_& val) { + if (val.type_code() == kTVMNullptr) { + return Integer(ObjectPtr(nullptr)); + } + if (val.type_code() == kDLInt) { + return Integer(val.operator int()); + } + return val.AsObjectRef(); } - TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); - Object* ptr = static_cast(value_.v_handle); - CHECK(ObjectTypeChecker::Check(ptr)) - << "Expect type " << ObjectTypeChecker::TypeName() - << " but get " << ptr->GetTypeKey(); - return Integer(ObjectPtr(ptr)); -} +}; } // namespace runtime } // namespace tvm diff --git a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc index e3446c93cc6f..b59956824d26 100644 --- a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc +++ b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc @@ -244,8 +244,9 @@ extern "C" int funcInvokeCallback(TVMValue *args, int tcode = typeCodes[i]; if (tcode == kTVMObjectHandle || tcode == kTVMPackedFuncHandle || + tcode == kTVMObjectRValueRefArg || tcode == kTVMModuleHandle) { - TVMCbArgToReturn(&arg, tcode); + TVMCbArgToReturn(&arg, &tcode); } jobject jarg = tvmRetValueToJava(env, arg, tcode); env->SetObjectArrayElement(jargs, i, jarg); diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py index 263a76d414a8..b5dc65fd5e79 100644 --- a/python/tvm/_ffi/_ctypes/object.py +++ b/python/tvm/_ffi/_ctypes/object.py @@ -60,6 +60,9 @@ def _return_object(x): C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func( _return_object, TypeCode.OBJECT_HANDLE) +C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_RVALUE_REF_ARG] = _wrap_arg_func( + _return_object, TypeCode.OBJECT_RVALUE_REF_ARG) + class ObjectBase(object): """Base object for all object types""" diff --git a/python/tvm/_ffi/_ctypes/packed_func.py b/python/tvm/_ffi/_ctypes/packed_func.py index 6e1dbf55f618..11bb65504c61 100644 --- a/python/tvm/_ffi/_ctypes/packed_func.py +++ b/python/tvm/_ffi/_ctypes/packed_func.py @@ -23,7 +23,7 @@ from ..base import _LIB, get_last_ffi_error, py2cerror, check_call from ..base import c_str, string_types -from ..runtime_ctypes import DataType, TVMByteArray, TVMContext +from ..runtime_ctypes import DataType, TVMByteArray, TVMContext, ObjectRValueRef from . import ndarray as _nd from .ndarray import NDArrayBase, _make_array from .types import TVMValue, TypeCode @@ -164,6 +164,9 @@ def _make_tvm_args(args, temp_args): elif isinstance(arg, ctypes.c_void_p): values[i].v_handle = arg type_codes[i] = TypeCode.HANDLE + elif isinstance(arg, ObjectRValueRef): + values[i].v_handle = ctypes.cast(ctypes.byref(arg.obj.handle), ctypes.c_void_p) + type_codes[i] = TypeCode.OBJECT_RVALUE_REF_ARG elif callable(arg): arg = convert_to_tvm_func(arg) values[i].v_handle = arg.handle diff --git a/python/tvm/_ffi/_ctypes/types.py b/python/tvm/_ffi/_ctypes/types.py index f45748fdd4de..20be30a59b2f 100644 --- a/python/tvm/_ffi/_ctypes/types.py +++ b/python/tvm/_ffi/_ctypes/types.py @@ -73,9 +73,9 @@ def _return_context(value): def _wrap_arg_func(return_f, type_code): - tcode = ctypes.c_int(type_code) def _wrap_func(x): - check_call(_LIB.TVMCbArgToReturn(ctypes.byref(x), tcode)) + tcode = ctypes.c_int(type_code) + check_call(_LIB.TVMCbArgToReturn(ctypes.byref(x), ctypes.byref(tcode))) return return_f(x) return _wrap_func diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index ad281d7512d9..0da66ac2e034 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -37,6 +37,7 @@ cdef enum TVMTypeCode: kTVMStr = 11 kTVMBytes = 12 kTVMNDArrayHandle = 13 + kTVMObjectRefArg = 14 kTVMExtBegin = 15 cdef extern from "tvm/runtime/c_runtime_api.h": @@ -113,7 +114,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h": void* resource_handle, TVMPackedCFuncFinalizer fin, TVMPackedFuncHandle *out) - int TVMCbArgToReturn(TVMValue* value, int code) + int TVMCbArgToReturn(TVMValue* value, int* code) int TVMArrayAlloc(tvm_index_t* shape, tvm_index_t ndim, DLDataType dtype, diff --git a/python/tvm/_ffi/_cython/object.pxi b/python/tvm/_ffi/_cython/object.pxi index 2a345cad684e..f2b5cc172d45 100644 --- a/python/tvm/_ffi/_cython/object.pxi +++ b/python/tvm/_ffi/_cython/object.pxi @@ -64,10 +64,7 @@ cdef class ObjectBase: property handle: def __get__(self): - if self.chandle == NULL: - return None - else: - return ctypes_handle(self.chandle) + return ctypes_handle(self.chandle) def __set__(self, value): self._set_handle(value) diff --git a/python/tvm/_ffi/_cython/packed_func.pxi b/python/tvm/_ffi/_cython/packed_func.pxi index 9d13dbd290d6..6977e108bf88 100644 --- a/python/tvm/_ffi/_cython/packed_func.pxi +++ b/python/tvm/_ffi/_cython/packed_func.pxi @@ -20,7 +20,7 @@ import traceback from cpython cimport Py_INCREF, Py_DECREF from numbers import Number, Integral from ..base import string_types, py2cerror -from ..runtime_ctypes import DataType, TVMContext, TVMByteArray +from ..runtime_ctypes import DataType, TVMContext, TVMByteArray, ObjectRValueRef cdef void tvm_callback_finalize(void* fhandle): @@ -43,8 +43,9 @@ cdef int tvm_callback(TVMValue* args, if (tcode == kTVMObjectHandle or tcode == kTVMPackedFuncHandle or tcode == kTVMModuleHandle or + tcode == kTVMObjectRefArg or tcode > kTVMExtBegin): - CALL(TVMCbArgToReturn(&value, tcode)) + CALL(TVMCbArgToReturn(&value, &tcode)) if tcode != kTVMDLTensorHandle: pyargs.append(make_ret(value, tcode)) @@ -167,6 +168,9 @@ cdef inline int make_arg(object arg, elif isinstance(arg, ctypes.c_void_p): value[0].v_handle = c_handle(arg) tcode[0] = kTVMOpaqueHandle + elif isinstance(arg, ObjectRValueRef): + value[0].v_handle = &(((arg.obj)).chandle) + tcode[0] = kTVMObjectRefArg elif callable(arg): arg = convert_to_tvm_func(arg) value[0].v_handle = (arg).chandle diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 160cc3ec9e21..6b06ad01c9ff 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -39,6 +39,7 @@ class TypeCode(object): STR = 11 BYTES = 12 NDARRAY_HANDLE = 13 + OBJECT_RVALUE_REF_ARG = 14 EXT_BEGIN = 15 @@ -281,4 +282,18 @@ class TVMArray(ctypes.Structure): ("strides", ctypes.POINTER(tvm_shape_index_t)), ("byte_offset", ctypes.c_uint64)] + +class ObjectRValueRef: + """Represent an RValue ref to an object that can be moved. + + Parameters + ---------- + obj : tvm.runtime.Object + The object that this value refers to + """ + __slots__ = ["obj"] + def __init__(self, obj): + self.obj = obj + + TVMArrayHandle = ctypes.POINTER(TVMArray) diff --git a/python/tvm/runtime/object.py b/python/tvm/runtime/object.py index f725c1908ba5..a55eeb0cb3ee 100644 --- a/python/tvm/runtime/object.py +++ b/python/tvm/runtime/object.py @@ -19,6 +19,7 @@ import ctypes from tvm._ffi.base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str +from tvm._ffi.runtime_ctypes import ObjectRValueRef from . import _ffi_api, _ffi_node_api try: @@ -85,5 +86,35 @@ def __setstate__(self, state): else: self.handle = None + def _move(self): + """Create an RValue reference to the object and mark the object as moved. + + This is a advanced developer API that can be useful when passing an + unique reference to an Object that you no longer needed to a function. + + A unique reference can trigger copy on write optimization that avoids + copy when we transform an object. + + Note + ---- + All the reference of the object becomes invalid after it is moved. + Be very careful when using this feature. + + Examples + -------- + + .. code-block:: python + + x = tvm.tir.Var("x", "int32") + x0 = x + some_packed_func(x._move()) + # both x0 and x will points to None after the function call. + + Returns + ------- + rvalue : The rvalue reference. + """ + return ObjectRValueRef(self) + _set_class_object(Object) diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index a7716df83189..ac20b67e8299 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -18,6 +18,7 @@ # pylint: disable=unused-import, invalid-name from numbers import Number, Integral from tvm._ffi.base import string_types +from tvm._ffi.runtime_ctypes import ObjectRValueRef from . import _ffi_node_api, _ffi_api from .object import ObjectBase, _set_class_object_generic @@ -33,7 +34,7 @@ def asobject(self): raise NotImplementedError() -ObjectTypes = (ObjectBase, NDArrayBase, Module) +ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef) def convert_to_object(value): diff --git a/rust/frontend/src/function.rs b/rust/frontend/src/function.rs index 7f055259fe0c..d9c0e5c1aa8b 100644 --- a/rust/frontend/src/function.rs +++ b/rust/frontend/src/function.rs @@ -261,7 +261,7 @@ unsafe extern "C" fn tvm_callback( || tcode == ffi::TVMTypeCode_kTVMPackedFuncHandle as c_int || tcode == ffi::TVMTypeCode_kTVMModuleHandle as c_int { - check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, tcode)); + check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, &mut tcode as *mut _)); } local_args.push(TVMArgValue::from_tvm_value(value, tcode as u32)); } diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 6e38aac92ec0..49c0ef414d4f 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -371,7 +371,7 @@ TVM_REGISTER_GLOBAL("transform.MakeModulePass") .set_body_typed( [](runtime::TypedPackedFunc pass_func, PassInfo pass_info) { - return ModulePass(pass_func, pass_info); + return ModulePass(pass_func, pass_info); }); TVM_REGISTER_GLOBAL("transform.RunPass") diff --git a/src/node/container.cc b/src/node/container.cc index bce2eeea0ec9..0949da0eade2 100644 --- a/src/node/container.cc +++ b/src/node/container.cc @@ -370,7 +370,6 @@ TVM_REGISTER_GLOBAL("node.MapGetItem") Object* ptr = static_cast(args[0].value().v_handle); if (ptr->IsInstance()) { - CHECK(args[1].type_code() == kTVMObjectHandle); auto* n = static_cast(ptr); auto it = n->data.find(args[1].operator ObjectRef()); CHECK(it != n->data.end()) diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 6a50bae2dbe9..fb1f74da2103 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -577,13 +577,11 @@ int TVMStreamStreamSynchronize(int device_type, API_END(); } -int TVMCbArgToReturn(TVMValue* value, int code) { +int TVMCbArgToReturn(TVMValue* value, int* code) { API_BEGIN(); tvm::runtime::TVMRetValue rv; - rv = tvm::runtime::TVMArgValue(*value, code); - int tcode; - rv.MoveToCHost(value, &tcode); - CHECK_EQ(tcode, code); + rv = tvm::runtime::TVMMovableArgValue_(*value, *code); + rv.MoveToCHost(value, code); API_END(); } diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 9053f6298999..90fcfff0eef3 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -107,11 +107,11 @@ TVM_REGISTER_GLOBAL("testing.ErrorTest") .set_body_typed(ErrorTest); // internal function used for debug and testing purposes -TVM_REGISTER_GLOBAL("testing.ndarray_use_count") +TVM_REGISTER_GLOBAL("testing.object_use_count") .set_body([](TVMArgs args, TVMRetValue *ret) { - runtime::NDArray nd = args[0]; - // substract the current one - *ret = (nd.use_count() - 1); + runtime::ObjectRef obj = args[0]; + // substract the current one because we always copy + // and get another value. + *ret = (obj.use_count() - 1); }); - } // namespace tvm diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index d0313c60d984..99b6ca25a162 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -51,7 +52,7 @@ TEST(PackedFunc, Node) { Var x; Var t = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { CHECK(args.num_args == 1); - CHECK(args.type_codes[0] == kTVMObjectHandle); + CHECK(args[0].IsObjectRef()); Var b = args[0]; CHECK(x.same_as(b)); *rv = b; @@ -269,6 +270,50 @@ TEST(PackedFunc, ObjectConversion) { pf2(ObjectRef(m), Module()); } +TEST(TypedPackedFunc, RValue) { + using namespace tvm; + using namespace tvm::runtime; + { + auto f = [](tir::Var x, bool move) { + if (move) { + CHECK(x.unique()); + } else { + CHECK(!x.unique()); + } + CHECK(x->name_hint == "x"); + return x; + }; + TypedPackedFunc tf(f); + + tir::Var var("x"); + CHECK(var.unique()); + f(var, false); + // move the result to the function. + tir::Var ret = f(std::move(var), true); + CHECK(!var.defined()); + } + + { + // pass child class. + auto f = [](PrimExpr x, bool move) { + if (move) { + CHECK(x.unique()); + } else { + CHECK(!x.unique()); + } + return x; + }; + TypedPackedFunc tf(f); + + tir::Var var("x"); + CHECK(var.unique()); + f(var, false); + f(std::move(var), true); + // auto conversion. + f(1, true); + } +} + int main(int argc, char ** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; diff --git a/tests/python/unittest/test_runtime_packed_func.py b/tests/python/unittest/test_runtime_packed_func.py index 3570fe149608..fcaceb3dcaeb 100644 --- a/tests/python/unittest/test_runtime_packed_func.py +++ b/tests/python/unittest/test_runtime_packed_func.py @@ -98,6 +98,33 @@ def test_ctx_func(ctx): x = tvm.testing.context_test(x, x.device_type, x.device_id) assert x == tvm.opencl(10) + +def test_rvalue_ref(): + def callback(x, expected_count): + assert expected_count == tvm.testing.object_use_count(x) + return x + + f = tvm.runtime.convert(callback) + + def check0(): + x = tvm.tir.Var("x", "int32") + assert tvm.testing.object_use_count(x) == 1 + f(x, 2) + y = f(x._move(), 1) + assert x.handle.value == None + + def check1(): + x = tvm.tir.Var("x", "int32") + assert tvm.testing.object_use_count(x) == 1 + y = f(x, 2) + z = f(x._move(), 2) + assert x.handle.value == None + assert y.handle.value is not None + + check0() + check1() + + def test_trace_default_action(): n = 2 x = te.placeholder((n,n,n), name="X", dtype="float32") @@ -269,7 +296,11 @@ def check_assign(dtype): for t in ["float64", "float32"]: check_assign(t) + + if __name__ == "__main__": + test_rvalue_ref() + exit(0) test_empty_array() test_get_global() test_get_callback_with_node() diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index 1d9b79eca875..b61e6bb9fa01 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -212,7 +212,7 @@ def my_module(name): if name == "get_arr": return lambda : nd elif name == "ref_count": - return lambda : tvm.testing.ndarray_use_count(nd) + return lambda : tvm.testing.object_use_count(nd) elif name == "get_elem": return lambda idx: nd.asnumpy()[idx] elif name == "get_arr_elem": diff --git a/web/tvm_runtime.js b/web/tvm_runtime.js index 0740efc5ff6a..b62b298d969e 100644 --- a/web/tvm_runtime.js +++ b/web/tvm_runtime.js @@ -105,6 +105,7 @@ var tvm_runtime = tvm_runtime || {}; var kTVMPackedFuncHandle = 10; var kTVMStr = 11; var kTVMBytes = 12; + var kTVMObjectRValueRefArg = 14; //----------------------------------------- // TVM CWrap library // ---------------------------------------- @@ -171,7 +172,7 @@ var tvm_runtime = tvm_runtime || {}; ("TVMCbArgToReturn", "number", ["number", // TVMValue* value - "number" // int code + "number" // int* code ]); var TVMFuncCreateFromCFunc = Module.cwrap @@ -496,12 +497,15 @@ var tvm_runtime = tvm_runtime || {}; var args = []; for (var i = 0; i < nargs; ++i) { var vptr = arg_value + i * SIZEOF_TVMVALUE; - var tcode = Module.getValue(arg_tcode + i * SIZEOF_INT, "i32"); + var tcodeptr = arg_tcode + i * SIZEOF_INT; + var tcode = Module.getValue(tcodeptr, "i32"); if (tcode == kTVMObjectHandle || + tcode == kTVMObjectRValueRefArg || tcode == kTVMPackedFuncHandle || tcode == kTVMModuleHandle) { - TVM_CALL(TVMCbArgToReturn(vptr, tcode)); + TVM_CALL(TVMCbArgToReturn(vptr, tcodeptr)); } + tcode = Module.getValue(tcodeptr, "i32"); args.push(TVMRetValueToJS(vptr, tcode)); } var rv = funcTable[handle].apply(null, args); From 58f61a34e2fbea1c86f2530a43bfb2e6b4e310cc Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 10 Apr 2020 11:01:23 -0700 Subject: [PATCH 2/3] Address review comments --- include/tvm/runtime/packed_func.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 6be461adabc4..2dcb4ffcfa79 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -548,7 +548,7 @@ class TVMArgValue : public TVMPODValue_ { * * We can only construct a movable argument once from a single argument position. * If the argument is passed as RValue reference, the result will be moved. - * We should only construct a MovableArg from a argument once, + * We should only construct a MovableArg from an argument once, * as the result will can moved. * * \note For internal development purpose only. @@ -903,7 +903,7 @@ class TVMRetValue : public TVMPODValue_ { template struct PackedFuncValueConverter { /*! - * \brief Convert an TObjectRef from an argument value. + * \brief Convert a TObjectRef from an argument value. * \param val The argument value. * \return the converted result. */ @@ -911,7 +911,7 @@ struct PackedFuncValueConverter { return val.AsObjectRef(); } /*! - * \brief Convert an TObjectRef from an argument value. + * \brief Convert a TObjectRef from a return value. * \param val The argument value. * \return the converted result. */ From 21b0fb54b1faad43d229e9a5696764a55276e0f6 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 10 Apr 2020 14:28:09 -0700 Subject: [PATCH 3/3] fix compilation --- include/tvm/node/container.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index cf2ac260d674..ba1edf84383e 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -26,6 +26,7 @@ #include #include #include +#include #include #include