From 4ae96a747c9a0bda9069dfb7c18d254c30b36623 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 24 Jun 2024 11:14:10 -0500 Subject: [PATCH 1/8] [Bugfix] Use OBJECT_REF macros to provide ContainerType Prior to this commit, the `tvm::relax::PatternContext`, `tvm::relay::AnnotatedRegionSet`, and `tvm::relay::CallGraph` classes explicitly defined constructors from `ObjectPtr`, and `operator->` implementations. However, they did not define `ContainerType`. As a result, any use of `T::ContainerType` (e.g. in `Downcast(obj)` or `obj.as()`) would incorrectly use the inherited `ObjectRef::ContainerType`. Since these downcast methods are erroneously comparing against `ObjectRef::ContainerType`, the type checks prior to downcasting were effectively disabled for these classes. This commit updates the `tvm::relax::PatternContext`, `tvm::relay::AnnotatedRegionSet`, and `tvm::relay::CallGraph` classes to use the `TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS` and `TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS` macros, which provide all required members for the TVM `ObjectRef` interface. --- include/tvm/relax/dataflow_pattern.h | 13 ++----------- src/relay/analysis/annotated_region_set.h | 17 +++-------------- src/relay/analysis/call_graph.h | 13 +------------ 3 files changed, 6 insertions(+), 37 deletions(-) diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h index f7094b221221..28aef60b80cb 100644 --- a/include/tvm/relax/dataflow_pattern.h +++ b/include/tvm/relax/dataflow_pattern.h @@ -277,19 +277,8 @@ class PatternContextNode : public Object { */ class PatternContext : public ObjectRef { public: - TVM_DLL explicit PatternContext(ObjectPtr n) : ObjectRef(n) {} TVM_DLL explicit PatternContext(bool incremental = false); - const PatternContextNode* operator->() const { - ICHECK(get() != nullptr); - return static_cast(get()); - } - - PatternContextNode* operator->() { - ICHECK(get() != nullptr); - return static_cast(get_mutable()); - } - /*! * \brief Build an edge constraint between two patterns (producer and consumer). * @@ -333,6 +322,8 @@ class PatternContext : public ObjectRef { /*! \brief The RAII-like exit of a constraint context scope */ TVM_DLL void ExitWithScope() const; + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PatternContext, ObjectRef, PatternContextNode); + private: friend class With; }; diff --git a/src/relay/analysis/annotated_region_set.h b/src/relay/analysis/annotated_region_set.h index 443bd5ec1da3..4c5ec1cfdf37 100644 --- a/src/relay/analysis/annotated_region_set.h +++ b/src/relay/analysis/annotated_region_set.h @@ -210,13 +210,6 @@ class AnnotatedRegionSet : public ObjectRef { data_ = std::move(n); } - /*! - * \brief Construct from an object pointer. - * - * \param n The object pointer. - */ - explicit AnnotatedRegionSet(ObjectPtr n) : ObjectRef(n) {} - /*! \return The begin iterator. */ iterator begin() { auto* n = operator->(); @@ -242,13 +235,6 @@ class AnnotatedRegionSet : public ObjectRef { return n->end(); } - /*! \return mutable pointers to the node. */ - AnnotatedRegionSetNode* operator->() const { - auto* ptr = get_mutable(); - ICHECK(ptr != nullptr); - return static_cast(ptr); - } - /*! \return The region an expression belongs to. */ AnnotatedRegion operator[](const Expr& expr) { const auto* n = operator->(); @@ -268,6 +254,9 @@ class AnnotatedRegionSet : public ObjectRef { static AnnotatedRegionSet Create(const Expr& expr, const Op& begin, const Op& end, const std::string& func_name = "default"); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AnnotatedRegionSet, ObjectRef, + AnnotatedRegionSetNode); + private: /*! \brief Helper class to construct a RegionSet from an expr.*/ class Creator; diff --git a/src/relay/analysis/call_graph.h b/src/relay/analysis/call_graph.h index 091891acd414..54ed00868360 100644 --- a/src/relay/analysis/call_graph.h +++ b/src/relay/analysis/call_graph.h @@ -207,12 +207,6 @@ class CallGraph : public ObjectRef { */ explicit CallGraph(IRModule module); - /*! - * \brief Construct from an object pointer. - * \param n The object pointer. - */ - explicit CallGraph(ObjectPtr n) : ObjectRef(n) {} - /*! \return The begin iterator. */ iterator begin() { auto* n = operator->(); @@ -287,12 +281,7 @@ class CallGraph : public ObjectRef { return (*n)[gvar_name]; } - /*! \return mutable pointers to the node. */ - CallGraphNode* operator->() const { - auto* ptr = get_mutable(); - ICHECK(ptr != nullptr); - return static_cast(ptr); - } + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CallGraph, ObjectRef, CallGraphNode); private: /*! \brief Overload the << operator to print a call graph. */ From 17bc5388f302bed7236254ae1fe6bdf7b905fe65 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sat, 15 Jun 2024 17:59:28 -0500 Subject: [PATCH 2/8] [TIR][FFI] Allow string in PrimFunc return statement Prior to this PR, a TIR PrimFunc could return an `int64`, `float64`, or void. Returning any other type, even if supported by TVM's FFI, would raise an exception during `MakePackedAPI`. With this PR, string literals can be returned from TIR PrimFuncs. Enabling this TIR functionality requires a change to the TVM FFI. Previously, a `TVMRetValue` holding a string would allocate a `std::string*` and store it to `value_.v_handle`. This return value depends on the C++ runtime, and cannot be constructed on all backends that only require the C API. With this commit, this is instead represented as a `const char*` in `value_.v_str`, matching the representation in `TVMArgValue`. Unlike `TVMArgValue`, since p`TVMRetValue` may own its return value, two new member variables `void (*f_deleter_)(void*)` and `void* f_deleter_arg_`. These are designed to be easy to call from a C API, and are set to a deleter function that will free the allocation (if required), and an argument to be passed into the deleter function. With this new representation, the return value for different use cases would be set as follows: * Return string from C++ PackedFunc. The backing allocation is made using `new std::string(std::move(value))`. This heap allocation is stored in `f_deleter_arg_`, so that `f_deleter_` can then call `delete static_cast(arg)`. Finally, `new_str->data()` is stored as `value.v_str`. * Return pointer to static string from C. A pointer to the static string is stored in `v_str`. Both `f_deleter_` and `f_deleter_arg_` are left with their default value of `nullptr`, indicating that the `TVMRetValue` does not need to free any memory on destruction. This is the return value used by `T.ret(T.StringImm("my_string"))` in a TIR PrimFunc. * Return pointer to heap allocation made in C. Assuming the allocation is made using `malloc()`, a pointer to the heap allocation is stored in both `value_.v_str` and `f_deleter_arg_`. The `f_deleter_` is set to `free`, indicating that `free(f_deleter_arg_)` should be called when the `TVMRetValue` is destructed. This functionality is possibly within the updated `TVMRetValue` API, but is not currently used. This commit adds unit tests for the return of strings from compiled TIR PrimFuncs, when either running the function locally, or over a RPC connection. --- include/tvm/ir/op.h | 2 +- include/tvm/runtime/packed_func.h | 205 +++++++++++------- python/tvm/tir/op.py | 5 +- src/node/attr_registry.h | 2 +- src/runtime/c_runtime_api.cc | 6 +- src/runtime/library_module.cc | 12 +- src/tir/transforms/make_packed_api.cc | 3 + .../codegen/test_target_codegen_llvm.py | 11 + tests/python/runtime/test_runtime_rpc.py | 22 ++ .../tvmscript/test_tvmscript_roundtrip.py | 16 ++ 10 files changed, 185 insertions(+), 99 deletions(-) diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index 6e6b8bee5fc3..d6f73e3c3142 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -466,7 +466,7 @@ inline OpRegEntry& OpRegEntry::set_attr( // NOLINT(*) ICHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; runtime::TVMRetValue rv; rv = value; - UpdateAttr(attr_name, rv, plevel); + UpdateAttr(attr_name, std::move(rv), plevel); return *this; } diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 7266f8c4a50a..37a0816504a9 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -585,6 +585,31 @@ class TVMPODValue_ { TVM_CHECK_TYPE_CODE(type_code_, kTVMOpaqueHandle); return value_.v_handle; } + operator std::string() const { + if (type_code_ == kTVMDataType) { + return DLDataType2String(operator DLDataType()); + } else if (type_code_ == kTVMBytes) { + TVMByteArray* arr = static_cast(value_.v_handle); + return std::string(arr->data, arr->size); + } else if (type_code_ == kTVMStr) { + return std::string(value_.v_str); + } else { + return AsObjectRef().operator std::string(); + } + } + operator TVMByteArray() const { + if (type_code_ == kTVMBytes) { + return *static_cast(value_.v_handle); + } else { + LOG(FATAL) << "Expected " + << "TVMByteArray but got " << ArgTypeCode2Str(type_code_); + } + } + + inline operator DLDataType() const; + + operator DataType() const { return DataType(operator DLDataType()); } + operator DLTensor*() const { if (type_code_ == kTVMDLTensorHandle || type_code_ == kTVMNDArrayHandle) { return static_cast(value_.v_handle); @@ -639,13 +664,16 @@ class TVMPODValue_ { friend class TVMArgsSetter; friend class TVMRetValue; friend class TVMMovableArgValue_; - TVMPODValue_() : type_code_(kTVMNullptr) {} + friend PackedFunc WrapPackedFunc(int (*)(TVMValue*, int*, int, TVMValue*, int*, void*), + const ObjectPtr&); + + TVMPODValue_() : value_{0}, type_code_(kTVMNullptr) {} TVMPODValue_(TVMValue value, int type_code) : value_(value), type_code_(type_code) {} /*! \brief The value */ - TVMValue value_; + TVMValue value_{0}; /*! \brief the type code */ - int type_code_; + int type_code_{kTVMNullptr}; }; /*! @@ -671,6 +699,10 @@ class TVMArgValue : public TVMPODValue_ { using TVMPODValue_::operator int; using TVMPODValue_::operator bool; using TVMPODValue_::operator void*; + using TVMPODValue_::operator std::string; + using TVMPODValue_::operator TVMByteArray; + using TVMPODValue_::operator DLDataType; + using TVMPODValue_::operator DataType; using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Device; @@ -680,18 +712,6 @@ class TVMArgValue : public TVMPODValue_ { using TVMPODValue_::IsObjectRef; // conversion operator. - operator std::string() const { - if (type_code_ == kTVMDataType) { - return DLDataType2String(operator DLDataType()); - } else if (type_code_ == kTVMBytes) { - TVMByteArray* arr = static_cast(value_.v_handle); - return std::string(arr->data, arr->size); - } else if (type_code_ == kTVMStr) { - return std::string(value_.v_str); - } else { - return AsObjectRef().operator std::string(); - } - } template operator TypedPackedFunc() const { return TypedPackedFunc(operator PackedFunc()); @@ -700,8 +720,6 @@ class TVMArgValue : public TVMPODValue_ { template ::value>::type> inline operator T() const; - inline operator DLDataType() const; - inline operator DataType() const; }; /*! @@ -724,6 +742,10 @@ class TVMMovableArgValue_ : public TVMPODValue_ { using TVMPODValue_::operator int; using TVMPODValue_::operator bool; using TVMPODValue_::operator void*; + using TVMPODValue_::operator std::string; + using TVMPODValue_::operator TVMByteArray; + using TVMPODValue_::operator DLDataType; + using TVMPODValue_::operator DataType; using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Device; @@ -735,8 +757,6 @@ class TVMMovableArgValue_ : public TVMPODValue_ { operator TypedPackedFunc() const { return TypedPackedFunc(operator PackedFunc()); } - operator DLDataType() const { return AsArgValue().operator DLDataType(); } - operator DataType() const { return AsArgValue().operator DataType(); } operator TVMArgValue() const { return AsArgValue(); } /*! * \brief Helper converter function. @@ -807,17 +827,19 @@ class TVMMovableArgValueWithContext_ { class TVMRetValue : public TVMPODValue_ { public: /*! \brief default constructor */ - TVMRetValue() {} + TVMRetValue() : f_deleter_(nullptr), f_deleter_arg_(nullptr) {} + /*! * \brief move constructor from another return value. * \param other The other return value. */ - TVMRetValue(TVMRetValue&& other) : TVMPODValue_(other.value_, other.type_code_) { - other.value_.v_handle = nullptr; - other.type_code_ = kTVMNullptr; - } + TVMRetValue(TVMRetValue&& other) : TVMRetValue() { *this = std::move(other); } + + TVMRetValue(const TVMRetValue& other) : TVMRetValue() { *this = other; } + /*! \brief destructor */ ~TVMRetValue() { this->Clear(); } + // reuse converter from parent using TVMPODValue_::operator double; using TVMPODValue_::operator int64_t; @@ -825,6 +847,10 @@ class TVMRetValue : public TVMPODValue_ { using TVMPODValue_::operator int; using TVMPODValue_::operator bool; using TVMPODValue_::operator void*; + using TVMPODValue_::operator std::string; + using TVMPODValue_::operator TVMByteArray; + using TVMPODValue_::operator DLDataType; + using TVMPODValue_::operator DataType; using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator Device; using TVMPODValue_::operator NDArray; @@ -833,35 +859,17 @@ class TVMRetValue : public TVMPODValue_ { using TVMPODValue_::AsObjectRef; using TVMPODValue_::IsObjectRef; - TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); } // conversion operators - operator std::string() const { - if (type_code_ == kTVMDataType) { - return DLDataType2String(operator DLDataType()); - } else if (type_code_ == kTVMBytes) { - return *ptr(); - } - TVM_CHECK_TYPE_CODE(type_code_, kTVMStr); - return *ptr(); - } - operator DLDataType() const { - if (type_code_ == kTVMStr) { - return String2DLDataType(operator std::string()); - } - TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType); - return value_.v_type; - } - operator DataType() const { return DataType(operator DLDataType()); } template operator TypedPackedFunc() const { return TypedPackedFunc(operator PackedFunc()); } // Assign operators TVMRetValue& operator=(TVMRetValue&& other) { - this->Clear(); - value_ = other.value_; - type_code_ = other.type_code_; - other.type_code_ = kTVMNullptr; + std::swap(value_, other.value_); + std::swap(type_code_, other.type_code_); + std::swap(f_deleter_, other.f_deleter_); + std::swap(f_deleter_arg_, other.f_deleter_arg_); return *this; } TVMRetValue& operator=(double value) { @@ -906,11 +914,49 @@ class TVMRetValue : public TVMPODValue_ { return *this; } TVMRetValue& operator=(std::string value) { - this->SwitchToClass(kTVMStr, value); + this->Clear(); + + std::string* container = new std::string(std::move(value)); + f_deleter_ = [](void* arg) { delete static_cast(arg); }; + f_deleter_arg_ = container; + + type_code_ = kTVMStr; + value_.v_str = container->c_str(); + return *this; } TVMRetValue& operator=(TVMByteArray value) { - this->SwitchToClass(kTVMBytes, std::string(value.data, value.size)); + this->Clear(); + + /* \brief Container for owned data + * + * For consistency with TVMArgValue, kTVMBytes should store a + * `TVMByteArray*` in `value_.v_handle`. However, `TVMRetValue` + * must own its backing allocation, where `TVMByteArray` does not + * own the data to which it points. + * + * This struct provides both ownership over an allocation, and a + * `TVMByteArray` with a view into the owned allocation. + */ + struct OwnedArray { + explicit OwnedArray(std::vector arg) + : backing_vector(std::move(arg)), array{backing_vector.data(), backing_vector.size()} {} + OwnedArray(const OwnedArray&) = delete; + + // The backing allocation + std::vector backing_vector; + + // The TVMByteArray, referencing the backing allocation + TVMByteArray array; + }; + + OwnedArray* container = new OwnedArray(std::vector(value.data, value.data + value.size)); + f_deleter_ = [](void* arg) { delete static_cast(arg); }; + f_deleter_arg_ = container; + + type_code_ = kTVMBytes; + value_.v_handle = &container->array; + return *this; } TVMRetValue& operator=(NDArray other) { @@ -999,11 +1045,11 @@ class TVMRetValue : public TVMPODValue_ { void Assign(const T& other) { switch (other.type_code()) { case kTVMStr: { - SwitchToClass(kTVMStr, other); + *this = other.operator std::string(); break; } case kTVMBytes: { - SwitchToClass(kTVMBytes, other); + *this = other.operator TVMByteArray(); break; } case kTVMPackedFuncHandle: { @@ -1042,16 +1088,6 @@ class TVMRetValue : public TVMPODValue_ { type_code_ = type_code; } } - template - void SwitchToClass(int type_code, T v) { - if (type_code_ != type_code) { - this->Clear(); - type_code_ = type_code; - value_.v_handle = new T(v); - } else { - *static_cast(value_.v_handle) = v; - } - } void SwitchToObject(int type_code, ObjectPtr other) { if (other.data_ != nullptr) { this->Clear(); @@ -1066,29 +1102,46 @@ class TVMRetValue : public TVMPODValue_ { } void Clear() { if (type_code_ == kTVMNullptr) return; + switch (type_code_) { case kTVMStr: - case kTVMBytes: - delete ptr(); + case kTVMBytes: { + if (f_deleter_) { + (*f_deleter_)(f_deleter_arg_); + } break; - case kTVMPackedFuncHandle: + } + + case kTVMPackedFuncHandle: { static_cast(value_.v_handle)->DecRef(); break; + } + case kTVMNDArrayHandle: { NDArray::FFIDecRef(static_cast(value_.v_handle)); break; } + case kTVMModuleHandle: { static_cast(value_.v_handle)->DecRef(); break; } + case kTVMObjectHandle: { static_cast(value_.v_handle)->DecRef(); break; } } type_code_ = kTVMNullptr; + f_deleter_ = nullptr; + f_deleter_arg_ = nullptr; } + + /* \brief The deleter function to call for owned values */ + void (*f_deleter_)(void*) = nullptr; + + /* \brief The argument to be provided to the deleter function */ + void* f_deleter_arg_ = nullptr; }; /*! @@ -1740,14 +1793,8 @@ class TVMArgsSetter { operator()(i, value.packed()); } void operator()(size_t i, const TVMRetValue& value) const { - if (value.type_code() == kTVMStr) { - values_[i].v_str = value.ptr()->c_str(); - type_codes_[i] = kTVMStr; - } else { - ICHECK_NE(value.type_code(), kTVMBytes) << "not handled."; - values_[i] = value.value_; - type_codes_[i] = value.type_code(); - } + values_[i] = value.value_; + type_codes_[i] = value.type_code(); } // ObjectRef handling template struct PackedFuncValueConverter<::tvm::runtime::String> { - static String From(const TVMArgValue& val) { - if (val.IsObjectRef()) { - return val.AsObjectRef(); - } else { - return tvm::runtime::String(val.operator std::string()); - } - } - - static String From(const TVMRetValue& val) { + static String From(const TVMPODValue_& val) { if (val.IsObjectRef()) { return val.AsObjectRef(); } else { @@ -2222,8 +2261,8 @@ inline bool String::CanConvertFrom(const TVMArgValue& val) { return val.type_code() == kTVMStr || val.IsObjectRef(); } -inline TVMArgValue::operator DLDataType() const { - if (String::CanConvertFrom(*this)) { +inline TVMPODValue_::operator DLDataType() const { + if (type_code_ == kTVMStr || IsObjectRef()) { return String2DLDataType(PackedFuncValueConverter::From(*this).operator std::string()); } // None type @@ -2238,8 +2277,6 @@ inline TVMArgValue::operator DLDataType() const { return value_.v_type; } -inline TVMArgValue::operator DataType() const { return DataType(operator DLDataType()); } - } // namespace runtime // NOLINT(*) } // namespace tvm // NOLINT(*) #endif // TVM_RUNTIME_PACKED_FUNC_H_ diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 81d6604259a3..901b653e1d45 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -1827,7 +1827,10 @@ def ret(val): The return expression """ - val = convert(val) + if isinstance(val, str): + val = StringImm(val) + else: + val = convert(val) return call_intrin(val.dtype, "tir.ret", val) diff --git a/src/node/attr_registry.h b/src/node/attr_registry.h index 050f9e5b2845..309ccf94eed1 100644 --- a/src/node/attr_registry.h +++ b/src/node/attr_registry.h @@ -113,7 +113,7 @@ class AttrRegistry { ICHECK(value.type_code() != kTVMNullptr) << "Registered packed_func is Null for " << attr_name << " of operator " << key->AttrRegistryName(); if (p.second < plevel && value.type_code() != kTVMNullptr) { - op_map->data_[index] = std::make_pair(value, plevel); + op_map->data_[index] = std::make_pair(std::move(value), plevel); } } diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index ea22b89dd771..699273f54d10 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -575,11 +575,7 @@ int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int // handle return string. if (rv.type_code() == kTVMStr || rv.type_code() == kTVMDataType || rv.type_code() == kTVMBytes) { TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get(); - if (rv.type_code() != kTVMDataType) { - e->ret_str = *rv.ptr(); - } else { - e->ret_str = rv.operator std::string(); - } + e->ret_str = rv.operator std::string(); if (rv.type_code() == kTVMBytes) { e->ret_bytes.data = e->ret_str.c_str(); e->ret_bytes.size = e->ret_str.length(); diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index 7b39bcd8da02..67b18f5bd709 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -69,21 +69,19 @@ class LibraryModuleNode final : public ModuleNode { PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& sptr_to_self) { return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - TVMValue ret_value; - int ret_type_code = kTVMNullptr; + TVMRetValue ret_value; + auto arg_values = const_cast(args.values); auto arg_type_codes = const_cast(args.type_codes); - int ret = - (*faddr)(arg_values, arg_type_codes, args.num_args, &ret_value, &ret_type_code, nullptr); + int ret = (*faddr)(arg_values, arg_type_codes, args.num_args, &ret_value.value_, + &ret_value.type_code_, nullptr); // NOTE: It is important to keep the original error message. // Using the `TVMThrowLastError()` function will also preserve the // full stack trace for debugging in pdb. if (ret != 0) { TVMThrowLastError(); } - if (ret_type_code != kTVMNullptr) { - *rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code); - } + *rv = std::move(ret_value); }); } diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index d327cdfa8393..234e034b0342 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -90,6 +90,9 @@ class ReturnRewriter : public StmtMutator { } else if (dtype.is_void()) { info.tcode = kTVMNullptr; info.expr = val; + } else if (val->IsInstance()) { + info.tcode = kTVMStr; + info.expr = val; } else { LOG(FATAL) << "data type " << dtype << " not supported yet"; } diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index f50d63878e4f..0e2013642e6f 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -1138,5 +1138,16 @@ def func(): tvm.build(func) +def test_return_string_from_tir(): + @T.prim_func + def func(): + return "hello!" + + built = tvm.build(func, target="llvm") + + out = built() + assert out == "hello!" + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/runtime/test_runtime_rpc.py b/tests/python/runtime/test_runtime_rpc.py index fbdc33928b6e..d9d9dbc7924a 100644 --- a/tests/python/runtime/test_runtime_rpc.py +++ b/tests/python/runtime/test_runtime_rpc.py @@ -734,3 +734,25 @@ def func_with_arg(unused: T.int64) -> T.int64: res = remote_mod["func_without_arg"]() assert res == 42 + + +def test_return_string_over_rpc(): + @T.prim_func + def func(unused: T.int64) -> T.handle: + return T.StringImm("hello!") + + built = tvm.build(func, target="llvm") + + server = tvm.rpc.Server(key="x1") + client = tvm.rpc.connect("127.0.0.1", server.port, key="x1") + + libname = "libbuilt.so" + with tempfile.TemporaryDirectory(prefix="tvm_rpc_testing_") as temp_dir: + local_path = os.path.join(temp_dir, libname) + built.export_library(local_path) + client.upload(local_path) + + remote_mod = client.load_module(libname) + + out = remote_mod(42) + assert out == "hello!" diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index f81a80de6d61..2fdfdfc65af5 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -4114,6 +4114,21 @@ def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")): return func +def tir_return_string_imm(): + """TIR StringImm must round-trip + + The conversion from Python str to TIR StringImm occurs at the + callee. + + """ + + @T.prim_func + def func(): + return T.StringImm("hello") + + return func + + ir_generator = tvm.testing.parameter( launch_env_thread, opt_gemm_normalize, @@ -4202,6 +4217,7 @@ def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")): *relax_match_cast_struct_info_proxy(), relax_symbolic_size_var, relax_float_symbolic_var, + tir_return_string_imm, ) relax_ir_generator = tvm.testing.parameter( From d1f4c59018e8196d8e8844cfd77a40239e13ae6e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 24 Jun 2024 12:39:25 -0500 Subject: [PATCH 3/8] [FFI] Convert tvm.runtime.String to python-native str Prior to this commit, a C++ `tvm::runtime::String` returned by a PackedFunc would be converted to a `tvm.runtime.String` by the Python FFI. There are two main issues with this handling. 1. Unnecessary duplication of memory. The python-native `str` constructed with `str.__new__` owns its underlying data. The argument passed to `runtime.GetFFIString` also owns its underlying data. Because the `GetFFIString` argument is saved to `val.__tvm_object__`, the python object returned from `String.__from_tvm_object__` keeps both copies alive. 2. Potential dangling references. The C++ `tvm::runtime::String` may have a deleter that is implemented in a dynamically-loaded LLVM or `.so` module. Removing the unnecessary copy and retaining the python-native `str` avoids this issue. This commit updates the Python implementation of `tvm.runtime.String.__from_tvm_object__` to return a python-native `str`. --- python/tvm/runtime/container.py | 10 ++++------ python/tvm/target/target.py | 4 ++-- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py index 686b4a26c80c..d23df9d41841 100644 --- a/python/tvm/runtime/container.py +++ b/python/tvm/runtime/container.py @@ -131,12 +131,10 @@ def __new__(cls, content): return val # pylint: disable=no-self-argument - def __from_tvm_object__(cls, obj): - """Construct from a given tvm object.""" - content = _ffi_api.GetFFIString(obj) - val = str.__new__(cls, content) - val.__tvm_object__ = obj - return val + def __from_tvm_object__(cls, obj: Object) -> str: + """Convert from runtime.String to native string""" + + return _ffi_api.GetFFIString(obj) @tvm._ffi.register_object("runtime.ShapeTuple") diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index ec74cbcdb62a..2882499fe44f 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -128,10 +128,10 @@ def __init__(self, target, host=None): target = convert(target) if isinstance(host, (dict, str)): host = convert(host) - if target is None or not isinstance(target, (Map, String, Target)): + if target is None or not isinstance(target, (Map, str, Target)): raise ValueError("target has to be a string or dictionary.") if host is not None: - if not isinstance(host, (Map, String, Target)): + if not isinstance(host, (Map, str, Target)): raise ValueError("target host has to be a string or dictionary.") self.__init_handle_by_constructor__(_ffi_api.Target, Target(target), Target(host)) else: From 6c363e95763bc473705a97089f78ac33f531bc60 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 18 Jun 2024 13:19:22 -0500 Subject: [PATCH 4/8] Re-implement to generate runtime::StringObj --- include/tvm/ir/op.h | 2 +- include/tvm/runtime/packed_func.h | 233 ++++++++---------- include/tvm/tir/builtin.h | 9 + src/node/attr_registry.h | 2 +- src/runtime/c_runtime_api.cc | 6 +- src/runtime/library_module.cc | 12 +- src/target/llvm/codegen_cpu.cc | 91 +++++++ src/target/llvm/codegen_cpu.h | 5 + src/tir/op/builtin.cc | 4 + src/tir/transforms/make_packed_api.cc | 6 +- .../codegen/test_target_codegen_llvm.py | 16 +- 11 files changed, 248 insertions(+), 138 deletions(-) diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index d6f73e3c3142..6e6b8bee5fc3 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -466,7 +466,7 @@ inline OpRegEntry& OpRegEntry::set_attr( // NOLINT(*) ICHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; runtime::TVMRetValue rv; rv = value; - UpdateAttr(attr_name, std::move(rv), plevel); + UpdateAttr(attr_name, rv, plevel); return *this; } diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 37a0816504a9..6a3ffd591195 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -585,31 +586,6 @@ class TVMPODValue_ { TVM_CHECK_TYPE_CODE(type_code_, kTVMOpaqueHandle); return value_.v_handle; } - operator std::string() const { - if (type_code_ == kTVMDataType) { - return DLDataType2String(operator DLDataType()); - } else if (type_code_ == kTVMBytes) { - TVMByteArray* arr = static_cast(value_.v_handle); - return std::string(arr->data, arr->size); - } else if (type_code_ == kTVMStr) { - return std::string(value_.v_str); - } else { - return AsObjectRef().operator std::string(); - } - } - operator TVMByteArray() const { - if (type_code_ == kTVMBytes) { - return *static_cast(value_.v_handle); - } else { - LOG(FATAL) << "Expected " - << "TVMByteArray but got " << ArgTypeCode2Str(type_code_); - } - } - - inline operator DLDataType() const; - - operator DataType() const { return DataType(operator DLDataType()); } - operator DLTensor*() const { if (type_code_ == kTVMDLTensorHandle || type_code_ == kTVMNDArrayHandle) { return static_cast(value_.v_handle); @@ -664,16 +640,13 @@ class TVMPODValue_ { friend class TVMArgsSetter; friend class TVMRetValue; friend class TVMMovableArgValue_; - friend PackedFunc WrapPackedFunc(int (*)(TVMValue*, int*, int, TVMValue*, int*, void*), - const ObjectPtr&); - - TVMPODValue_() : value_{0}, type_code_(kTVMNullptr) {} + TVMPODValue_() : type_code_(kTVMNullptr) {} TVMPODValue_(TVMValue value, int type_code) : value_(value), type_code_(type_code) {} /*! \brief The value */ - TVMValue value_{0}; + TVMValue value_; /*! \brief the type code */ - int type_code_{kTVMNullptr}; + int type_code_; }; /*! @@ -699,10 +672,6 @@ class TVMArgValue : public TVMPODValue_ { using TVMPODValue_::operator int; using TVMPODValue_::operator bool; using TVMPODValue_::operator void*; - using TVMPODValue_::operator std::string; - using TVMPODValue_::operator TVMByteArray; - using TVMPODValue_::operator DLDataType; - using TVMPODValue_::operator DataType; using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Device; @@ -712,6 +681,18 @@ class TVMArgValue : public TVMPODValue_ { using TVMPODValue_::IsObjectRef; // conversion operator. + operator std::string() const { + if (type_code_ == kTVMDataType) { + return DLDataType2String(operator DLDataType()); + } else if (type_code_ == kTVMBytes) { + TVMByteArray* arr = static_cast(value_.v_handle); + return std::string(arr->data, arr->size); + } else if (type_code_ == kTVMStr) { + return std::string(value_.v_str); + } else { + return AsObjectRef().operator std::string(); + } + } template operator TypedPackedFunc() const { return TypedPackedFunc(operator PackedFunc()); @@ -720,6 +701,8 @@ class TVMArgValue : public TVMPODValue_ { template ::value>::type> inline operator T() const; + inline operator DLDataType() const; + inline operator DataType() const; }; /*! @@ -742,10 +725,6 @@ class TVMMovableArgValue_ : public TVMPODValue_ { using TVMPODValue_::operator int; using TVMPODValue_::operator bool; using TVMPODValue_::operator void*; - using TVMPODValue_::operator std::string; - using TVMPODValue_::operator TVMByteArray; - using TVMPODValue_::operator DLDataType; - using TVMPODValue_::operator DataType; using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Device; @@ -757,6 +736,8 @@ class TVMMovableArgValue_ : public TVMPODValue_ { operator TypedPackedFunc() const { return TypedPackedFunc(operator PackedFunc()); } + operator DLDataType() const { return AsArgValue().operator DLDataType(); } + operator DataType() const { return AsArgValue().operator DataType(); } operator TVMArgValue() const { return AsArgValue(); } /*! * \brief Helper converter function. @@ -827,19 +808,17 @@ class TVMMovableArgValueWithContext_ { class TVMRetValue : public TVMPODValue_ { public: /*! \brief default constructor */ - TVMRetValue() : f_deleter_(nullptr), f_deleter_arg_(nullptr) {} - + TVMRetValue() {} /*! * \brief move constructor from another return value. * \param other The other return value. */ - TVMRetValue(TVMRetValue&& other) : TVMRetValue() { *this = std::move(other); } - - TVMRetValue(const TVMRetValue& other) : TVMRetValue() { *this = other; } - + TVMRetValue(TVMRetValue&& other) : TVMPODValue_(other.value_, other.type_code_) { + other.value_.v_handle = nullptr; + other.type_code_ = kTVMNullptr; + } /*! \brief destructor */ ~TVMRetValue() { this->Clear(); } - // reuse converter from parent using TVMPODValue_::operator double; using TVMPODValue_::operator int64_t; @@ -847,10 +826,6 @@ class TVMRetValue : public TVMPODValue_ { using TVMPODValue_::operator int; using TVMPODValue_::operator bool; using TVMPODValue_::operator void*; - using TVMPODValue_::operator std::string; - using TVMPODValue_::operator TVMByteArray; - using TVMPODValue_::operator DLDataType; - using TVMPODValue_::operator DataType; using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator Device; using TVMPODValue_::operator NDArray; @@ -859,17 +834,35 @@ class TVMRetValue : public TVMPODValue_ { using TVMPODValue_::AsObjectRef; using TVMPODValue_::IsObjectRef; + TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); } // conversion operators + operator std::string() const { + if (type_code_ == kTVMDataType) { + return DLDataType2String(operator DLDataType()); + } else if (type_code_ == kTVMBytes) { + return *ptr(); + } + TVM_CHECK_TYPE_CODE(type_code_, kTVMStr); + return *ptr(); + } + operator DLDataType() const { + if (type_code_ == kTVMStr) { + return String2DLDataType(operator std::string()); + } + TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType); + return value_.v_type; + } + operator DataType() const { return DataType(operator DLDataType()); } template operator TypedPackedFunc() const { return TypedPackedFunc(operator PackedFunc()); } // Assign operators TVMRetValue& operator=(TVMRetValue&& other) { - std::swap(value_, other.value_); - std::swap(type_code_, other.type_code_); - std::swap(f_deleter_, other.f_deleter_); - std::swap(f_deleter_arg_, other.f_deleter_arg_); + this->Clear(); + value_ = other.value_; + type_code_ = other.type_code_; + other.type_code_ = kTVMNullptr; return *this; } TVMRetValue& operator=(double value) { @@ -914,49 +907,11 @@ class TVMRetValue : public TVMPODValue_ { return *this; } TVMRetValue& operator=(std::string value) { - this->Clear(); - - std::string* container = new std::string(std::move(value)); - f_deleter_ = [](void* arg) { delete static_cast(arg); }; - f_deleter_arg_ = container; - - type_code_ = kTVMStr; - value_.v_str = container->c_str(); - + this->SwitchToClass(kTVMStr, value); return *this; } TVMRetValue& operator=(TVMByteArray value) { - this->Clear(); - - /* \brief Container for owned data - * - * For consistency with TVMArgValue, kTVMBytes should store a - * `TVMByteArray*` in `value_.v_handle`. However, `TVMRetValue` - * must own its backing allocation, where `TVMByteArray` does not - * own the data to which it points. - * - * This struct provides both ownership over an allocation, and a - * `TVMByteArray` with a view into the owned allocation. - */ - struct OwnedArray { - explicit OwnedArray(std::vector arg) - : backing_vector(std::move(arg)), array{backing_vector.data(), backing_vector.size()} {} - OwnedArray(const OwnedArray&) = delete; - - // The backing allocation - std::vector backing_vector; - - // The TVMByteArray, referencing the backing allocation - TVMByteArray array; - }; - - OwnedArray* container = new OwnedArray(std::vector(value.data, value.data + value.size)); - f_deleter_ = [](void* arg) { delete static_cast(arg); }; - f_deleter_arg_ = container; - - type_code_ = kTVMBytes; - value_.v_handle = &container->array; - + this->SwitchToClass(kTVMBytes, std::string(value.data, value.size)); return *this; } TVMRetValue& operator=(NDArray other) { @@ -1021,9 +976,19 @@ class TVMRetValue : public TVMPODValue_ { static TVMRetValue MoveFromCHost(TVMValue value, int type_code) { // Can move POD and everything under the object system. ICHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle); + TVMRetValue ret; ret.value_ = value; ret.type_code_ = type_code; + + if (ret.type_code_ == kTVMObjectHandle) { + // A C implementation may not have performed the same type + // normalization is performed through the C++ API. For example, + // constructing a `tvm::runtime::String` rather than the + // implementation-dependent `std::string`. + ret = ret.operator ObjectRef(); + } + return ret; } /*! \return The value field, if the data is POD */ @@ -1045,11 +1010,11 @@ class TVMRetValue : public TVMPODValue_ { void Assign(const T& other) { switch (other.type_code()) { case kTVMStr: { - *this = other.operator std::string(); + SwitchToClass(kTVMStr, other); break; } case kTVMBytes: { - *this = other.operator TVMByteArray(); + SwitchToClass(kTVMBytes, other); break; } case kTVMPackedFuncHandle: { @@ -1088,6 +1053,16 @@ class TVMRetValue : public TVMPODValue_ { type_code_ = type_code; } } + template + void SwitchToClass(int type_code, T v) { + if (type_code_ != type_code) { + this->Clear(); + type_code_ = type_code; + value_.v_handle = new T(v); + } else { + *static_cast(value_.v_handle) = v; + } + } void SwitchToObject(int type_code, ObjectPtr other) { if (other.data_ != nullptr) { this->Clear(); @@ -1102,46 +1077,29 @@ class TVMRetValue : public TVMPODValue_ { } void Clear() { if (type_code_ == kTVMNullptr) return; - switch (type_code_) { case kTVMStr: - case kTVMBytes: { - if (f_deleter_) { - (*f_deleter_)(f_deleter_arg_); - } + case kTVMBytes: + delete ptr(); break; - } - - case kTVMPackedFuncHandle: { + case kTVMPackedFuncHandle: static_cast(value_.v_handle)->DecRef(); break; - } - case kTVMNDArrayHandle: { NDArray::FFIDecRef(static_cast(value_.v_handle)); break; } - case kTVMModuleHandle: { static_cast(value_.v_handle)->DecRef(); break; } - case kTVMObjectHandle: { static_cast(value_.v_handle)->DecRef(); break; } } type_code_ = kTVMNullptr; - f_deleter_ = nullptr; - f_deleter_arg_ = nullptr; } - - /* \brief The deleter function to call for owned values */ - void (*f_deleter_)(void*) = nullptr; - - /* \brief The argument to be provided to the deleter function */ - void* f_deleter_arg_ = nullptr; }; /*! @@ -1793,8 +1751,14 @@ class TVMArgsSetter { operator()(i, value.packed()); } void operator()(size_t i, const TVMRetValue& value) const { - values_[i] = value.value_; - type_codes_[i] = value.type_code(); + if (value.type_code() == kTVMStr) { + values_[i].v_str = value.ptr()->c_str(); + type_codes_[i] = kTVMStr; + } else { + ICHECK_NE(value.type_code(), kTVMBytes) << "not handled."; + values_[i] = value.value_; + type_codes_[i] = value.type_code(); + } } // ObjectRef handling template (static_cast(value_.v_handle))); - } else { - TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); - return TObjectRef(ObjectPtr(nullptr)); + } else if constexpr (std::is_base_of::value) { + if (type_code_ == kTVMStr) { + return runtime::String(value_.v_str); + } } + + TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); + return TObjectRef(ObjectPtr(nullptr)); } template @@ -2149,6 +2117,13 @@ inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { ptr->IsInstance())) { return operator=(PackedFunc(std::move(other.data_))); } + if (std::is_base_of::value || + (std::is_base_of::value && + ptr->IsInstance())) { + const auto* string_obj = other.template as(); + return operator=(std::string(string_obj->data, string_obj->size)); + } + SwitchToObject(kTVMObjectHandle, std::move(other.data_)); } else { SwitchToPOD(kTVMNullptr); @@ -2186,7 +2161,15 @@ inline PackedFunc Module::GetFunction(const String& name, bool query_imports) { // specializations of PackedFuncValueConverter template <> struct PackedFuncValueConverter<::tvm::runtime::String> { - static String From(const TVMPODValue_& val) { + static String From(const TVMArgValue& val) { + if (val.IsObjectRef()) { + return val.AsObjectRef(); + } else { + return tvm::runtime::String(val.operator std::string()); + } + } + + static String From(const TVMRetValue& val) { if (val.IsObjectRef()) { return val.AsObjectRef(); } else { @@ -2261,8 +2244,8 @@ inline bool String::CanConvertFrom(const TVMArgValue& val) { return val.type_code() == kTVMStr || val.IsObjectRef(); } -inline TVMPODValue_::operator DLDataType() const { - if (type_code_ == kTVMStr || IsObjectRef()) { +inline TVMArgValue::operator DLDataType() const { + if (String::CanConvertFrom(*this)) { return String2DLDataType(PackedFuncValueConverter::From(*this).operator std::string()); } // None type @@ -2277,6 +2260,8 @@ inline TVMPODValue_::operator DLDataType() const { return value_.v_type; } +inline TVMArgValue::operator DataType() const { return DataType(operator DLDataType()); } + } // namespace runtime // NOLINT(*) } // namespace tvm // NOLINT(*) #endif // TVM_RUNTIME_PACKED_FUNC_H_ diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 120c1b71be72..b8e0ad6a1b0a 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -283,6 +283,15 @@ TVM_DLL const Op& tvm_struct_get(); */ TVM_DLL const Op& tvm_struct_set(); +/*! + * \brief TIR constructor for tvm::runtime::StringObj + * + * runtime::StringObj* tvm_string_obj(StringImm value) { + * return new StringObj(value); + * } + */ +TVM_DLL const Op& tvm_string_obj(); + /*! * \brief See pseudo code * Type lookup_param(String param_name) { diff --git a/src/node/attr_registry.h b/src/node/attr_registry.h index 309ccf94eed1..050f9e5b2845 100644 --- a/src/node/attr_registry.h +++ b/src/node/attr_registry.h @@ -113,7 +113,7 @@ class AttrRegistry { ICHECK(value.type_code() != kTVMNullptr) << "Registered packed_func is Null for " << attr_name << " of operator " << key->AttrRegistryName(); if (p.second < plevel && value.type_code() != kTVMNullptr) { - op_map->data_[index] = std::make_pair(std::move(value), plevel); + op_map->data_[index] = std::make_pair(value, plevel); } } diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 699273f54d10..ea22b89dd771 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -575,7 +575,11 @@ int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int // handle return string. if (rv.type_code() == kTVMStr || rv.type_code() == kTVMDataType || rv.type_code() == kTVMBytes) { TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get(); - e->ret_str = rv.operator std::string(); + if (rv.type_code() != kTVMDataType) { + e->ret_str = *rv.ptr(); + } else { + e->ret_str = rv.operator std::string(); + } if (rv.type_code() == kTVMBytes) { e->ret_bytes.data = e->ret_str.c_str(); e->ret_bytes.size = e->ret_str.length(); diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index 67b18f5bd709..7b39bcd8da02 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -69,19 +69,21 @@ class LibraryModuleNode final : public ModuleNode { PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& sptr_to_self) { return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - TVMRetValue ret_value; - + TVMValue ret_value; + int ret_type_code = kTVMNullptr; auto arg_values = const_cast(args.values); auto arg_type_codes = const_cast(args.type_codes); - int ret = (*faddr)(arg_values, arg_type_codes, args.num_args, &ret_value.value_, - &ret_value.type_code_, nullptr); + int ret = + (*faddr)(arg_values, arg_type_codes, args.num_args, &ret_value, &ret_type_code, nullptr); // NOTE: It is important to keep the original error message. // Using the `TVMThrowLastError()` function will also preserve the // full stack trace for debugging in pdb. if (ret != 0) { TVMThrowLastError(); } - *rv = std::move(ret_value); + if (ret_type_code != kTVMNullptr) { + *rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code); + } }); } diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 481ba39cc7b1..f21399a111be 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -399,6 +399,93 @@ llvm::Value* CodeGenCPU::CreateCallExtern(Type ret_type, String global_symbol, return builder_->CreateCall(ftype, callee, arg_values); } +llvm::Value* CodeGenCPU::CreateStringObj(StringImm ir_string) { + auto t_deleter = llvm::FunctionType::get(t_void_, {t_void_p_}, false); + + if (!t_tvm_base_object_) { + t_tvm_base_object_ = llvm::StructType::create( + { + t_int32_ /* type_index_ */, + t_int32_ /* ref_counter_ */, + t_deleter->getPointerTo() /* deleter_ */, + }, + "tvm::runtime::Object", true); + } + if (!t_tvm_string_obj_) { + t_tvm_string_obj_ = llvm::StructType::create( + { + t_tvm_base_object_, + t_char_->getPointerTo() /* data */, + t_int64_ /* size */, + }, + "tvm::runtime::StringObj"); + } + if (!f_string_obj_deleter_) { + auto prev_insert_point = builder_->GetInsertBlock(); + + f_string_obj_deleter_ = llvm::Function::Create(t_deleter, llvm::Function::PrivateLinkage, + "string_obj_deleter", module_.get()); + + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + + auto* arg = f_string_obj_deleter_->getArg(0); + arg->setName("base_obj_ptr"); + + auto* entry = llvm::BasicBlock::Create(*ctx, "deleter_entry", f_string_obj_deleter_); + + builder_->SetInsertPoint(entry); + + // Currently, the only pointers stored in a generated StringObj + // are pointers to static allocations made using + // llvm::ConstantDataArray::getString. These static allocations + // should not be deleted after use, so the only allocation that + // needs to be deleted is the `CreateMalloc` containing the + // `StringObj` itself. + llvm::Instruction* free_inst = llvm::CallInst::CreateFree(arg, builder_->GetInsertBlock()); + + builder_->Insert(free_inst); + builder_->CreateRetVoid(); + + builder_->SetInsertPoint(prev_insert_point); + } + + llvm::Value* alloc_size = + llvm::ConstantInt::get(t_int64_, data_layout_->getTypeAllocSize(t_tvm_string_obj_)); + llvm::Instruction* malloc_inst = llvm::CallInst::CreateMalloc( + builder_->GetInsertBlock(), t_int64_, t_tvm_string_obj_, alloc_size, nullptr, nullptr, ""); + + builder_->Insert(malloc_inst); + + llvm::Value* out = builder_->CreatePointerCast(malloc_inst, t_tvm_string_obj_->getPointerTo()); + + llvm::Value* string_obj = builder_->CreateInBoundsGEP(t_tvm_string_obj_, out, ConstInt32(0)); + + builder_->CreateStore(ConstInt32(TypeIndex::kRuntimeString), + builder_->CreateInBoundsGEP(t_tvm_string_obj_, string_obj, + {ConstInt32(0), ConstInt32(0), ConstInt32(0)}, + "output->type_index_")); + + builder_->CreateStore(ConstInt32(1), + builder_->CreateInBoundsGEP(t_tvm_string_obj_, string_obj, + {ConstInt32(0), ConstInt32(0), ConstInt32(1)}, + "output->ref_counter_")); + + builder_->CreateStore(f_string_obj_deleter_, + builder_->CreateInBoundsGEP(t_tvm_string_obj_, string_obj, + {ConstInt32(0), ConstInt32(0), ConstInt32(2)}, + "output->deleter_")); + builder_->CreateStore( + GetConstString(ir_string->value), + builder_->CreateInBoundsGEP(t_tvm_string_obj_, string_obj, {ConstInt32(0), ConstInt32(1)}, + "output->data")); + builder_->CreateStore( + llvm::ConstantInt::getSigned(t_int64_, ir_string->value.size()), + builder_->CreateInBoundsGEP(t_tvm_string_obj_, string_obj, {ConstInt32(0), ConstInt32(2)}, + "output->size")); + + return builder_->CreatePointerCast(out, t_void_p_); +} + llvm::GlobalVariable* CodeGenCPU::InitContextPtr(llvm::Type* p_type, std::string name) { llvm::GlobalVariable* gv = new llvm::GlobalVariable( *module_, p_type, false, llvm::GlobalValue::LinkOnceAnyLinkage, nullptr, name); @@ -1381,6 +1468,10 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { } builder_->CreateStore(value, ref.addr); return ConstInt32(0); + } else if (op->op.same_as(builtin::tvm_string_obj())) { + ICHECK_EQ(op->args.size(), 1U); + ICHECK(op->args[0].dtype() == DataType::Handle()); + return CreateStringObj(Downcast(op->args[0])); } else if (op->op.same_as(builtin::tvm_stack_alloca())) { ICHECK_EQ(op->args.size(), 2U); const std::string& type = op->args[0].as()->value; diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index 91fe1bc18631..cb4c2531e771 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -76,6 +76,7 @@ class CodeGenCPU : public CodeGenLLVM { llvm::Value* CreateIntrinsic(const CallNode* op) override; llvm::Value* CreateCallExtern(Type ret_type, String global_symbol, const Array& args, bool skip_first_arg) override; + llvm::Value* CreateStringObj(StringImm ir_string); /*! * \brief A CPU-specific function to create the FuncRegistry. @@ -102,6 +103,10 @@ class CodeGenCPU : public CodeGenLLVM { llvm::StructType* t_tvm_value_{nullptr}; llvm::StructType* t_tvm_parallel_group_env_{nullptr}; + llvm::StructType* t_tvm_base_object_{nullptr}; + llvm::StructType* t_tvm_string_obj_{nullptr}; + llvm::Function* f_string_obj_deleter_{nullptr}; + llvm::FunctionType* ftype_tvm_backend_packed_c_func_{nullptr}; llvm::StructType* t_tvm_crt_func_registry_{nullptr}; llvm::StructType* t_tvm_crt_module_{nullptr}; diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 0404fd28230e..820b4f65a845 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -173,6 +173,10 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_struct_set) .set_num_inputs(4) .set_attr("TCallEffectKind", Integer(CallEffectKind::kUpdateState)); +TIR_DEFINE_BUILTIN_FUNC(tvm_string_obj) + .set_num_inputs(1) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_BUILTIN_FUNC(lookup_param) .set_num_inputs(4) .set_attr("TCallEffectKind", Integer(CallEffectKind::kUpdateState)); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 234e034b0342..4df063dc5d38 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -90,9 +90,9 @@ class ReturnRewriter : public StmtMutator { } else if (dtype.is_void()) { info.tcode = kTVMNullptr; info.expr = val; - } else if (val->IsInstance()) { - info.tcode = kTVMStr; - info.expr = val; + } else if (val.as()) { + info.tcode = kTVMObjectHandle; + info.expr = tir::Call(DataType::Handle(), builtin::tvm_string_obj(), {val}); } else { LOG(FATAL) << "data type " << dtype << " not supported yet"; } diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index 0e2013642e6f..31d762209b2a 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -18,10 +18,13 @@ import ctypes import json import math -import numpy as np -import pytest +import os import re import sys +import tempfile + +import numpy as np +import pytest import tvm import tvm.testing @@ -1138,13 +1141,20 @@ def func(): tvm.build(func) -def test_return_string_from_tir(): +@pytest.mark.parametrize("save_and_reload", [True, False]) +def test_return_string_from_tir(save_and_reload): @T.prim_func def func(): return "hello!" built = tvm.build(func, target="llvm") + if save_and_reload: + with tempfile.TemporaryDirectory(prefix="tvm_testing_") as temp_dir: + temp_file = os.path.join(temp_dir, "libbuilt.so") + built.export_library(temp_file) + built = tvm.runtime.load_module(temp_file) + out = built() assert out == "hello!" From dcb4d7d78a114ea3b2be287791e20fcfdb2e019a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 24 Jun 2024 19:24:17 -0500 Subject: [PATCH 5/8] lint fix --- python/tvm/target/target.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 2882499fe44f..329604e191c9 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -24,7 +24,6 @@ from tvm._ffi import register_func as _register_func from tvm._ffi.runtime_ctypes import Device from tvm.runtime import Object, convert -from tvm.runtime.container import String from tvm.ir.container import Map, Array from . import _ffi_api From e1a47c314867c67ed16d7d7ed3114ffc14acc007 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 27 Jun 2024 10:31:44 -0500 Subject: [PATCH 6/8] Updated rust bindings to handle runtime::String --- include/tvm/runtime/packed_func.h | 9 --------- rust/tvm-rt/src/array.rs | 6 +++--- rust/tvm-rt/src/object/object_ptr.rs | 17 ++++++++++++++++- src/runtime/container.cc | 21 ++++++++++++--------- src/runtime/minrpc/rpc_reference.h | 8 ++++++++ src/runtime/rpc/rpc_local_session.cc | 4 ++++ 6 files changed, 43 insertions(+), 22 deletions(-) diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 6a3ffd591195..ef37ace19f5e 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -980,15 +980,6 @@ class TVMRetValue : public TVMPODValue_ { TVMRetValue ret; ret.value_ = value; ret.type_code_ = type_code; - - if (ret.type_code_ == kTVMObjectHandle) { - // A C implementation may not have performed the same type - // normalization is performed through the C++ API. For example, - // constructing a `tvm::runtime::String` rather than the - // implementation-dependent `std::string`. - ret = ret.operator ObjectRef(); - } - return ret; } /*! \return The value field, if the data is POD */ diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index 02c34a1d133f..f21ba227cd32 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -40,7 +40,7 @@ pub struct Array { // the implementation. external! { #[name("runtime.ArrayGetItem")] - fn array_get_item(array: ObjectRef, index: isize) -> ObjectRef; + fn array_get_item(array: ObjectRef, index: isize) -> RetValue; #[name("runtime.ArraySize")] fn array_size(array: ObjectRef) -> i64; } @@ -96,8 +96,8 @@ impl Array { where T: TryFrom, { - let oref: ObjectRef = array_get_item(self.object.clone(), index)?; - oref.downcast() + let oref = array_get_item(self.object.clone(), index)?; + oref.try_into() } pub fn len(&self) -> i64 { diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 09d6068f1a88..b1907dbadccc 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -17,7 +17,7 @@ * under the License. */ -use std::convert::TryFrom; +use std::convert::{TryFrom, TryInto}; use std::ffi::CString; use std::fmt; use std::os::raw::c_char; @@ -31,6 +31,8 @@ use tvm_sys::ffi::{ use tvm_sys::{ArgValue, RetValue}; use crate::errors::Error; +use crate::IsObjectRef; +use crate::String as TVMString; type Deleter = unsafe extern "C" fn(object: *mut Object) -> (); @@ -320,6 +322,19 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { debug_assert!(optr.count() >= 1); optr.upcast::().downcast() } + RetValue::String(_) | RetValue::Str(_) => { + let string: String = ret_value.try_into().expect("Known to contain a string"); + + let string: TVMString = string.into(); + + let string = string + .into_ptr() + .expect("Known to contain a non-nullptr string"); + + debug_assert!(string.count() >= 1); + + string.upcast::().downcast() + } _ => Err(Error::downcast(format!("{:?}", ret_value), T::TYPE_KEY)), } } diff --git a/src/runtime/container.cc b/src/runtime/container.cc index 7b5105a3fc94..7d49e0bef427 100644 --- a/src/runtime/container.cc +++ b/src/runtime/container.cc @@ -49,15 +49,18 @@ TVM_REGISTER_GLOBAL("runtime.Array").set_body([](TVMArgs args, TVMRetValue* ret) *ret = Array(data); }); -TVM_REGISTER_GLOBAL("runtime.ArrayGetItem").set_body([](TVMArgs args, TVMRetValue* ret) { - int64_t i = args[1]; - ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - ICHECK(ptr->IsInstance()); - auto* n = static_cast(ptr); - ICHECK_LT(static_cast(i), n->size()) << "out of bound of array"; - *ret = n->at(i); -}); +// TVM_REGISTER_GLOBAL("runtime.ArrayGetItem").set_body([](TVMArgs args, TVMRetValue* ret) { +// int64_t i = args[1]; +// ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); +// Object* ptr = static_cast(args[0].value().v_handle); +// ICHECK(ptr->IsInstance()); +// auto* n = static_cast(ptr); +// ICHECK_LT(static_cast(i), n->size()) << "out of bound of array"; +// *ret = n->at(i); +// }); + +TVM_REGISTER_GLOBAL("runtime.ArrayGetItem") + .set_body_typed([](Array arr, size_t index) -> ObjectRef { return arr[index]; }); TVM_REGISTER_GLOBAL("runtime.ArraySize").set_body([](TVMArgs args, TVMRetValue* ret) { ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index d08dadb02bb9..bdd34ea4da49 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -33,6 +33,14 @@ class Object; /*! \brief The current RPC procotol version. */ constexpr const char* kRPCProtocolVer = "0.8.0"; +/*! + * \brief type index of kRuntimeString + * \note this needs to be kept consistent with runtime/object.h + * but we explicitly declare it here because minrpc needs to be minimum dep + * only c C API + */ +constexpr const int kRuntimeString = 3; + /*! * \brief type index of kRuntimeRPCObjectRefTypeIndex * \note this needs to be kept consistent with runtime/object.h diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc index 92691ee6fd28..25059a74797f 100644 --- a/src/runtime/rpc/rpc_local_session.cc +++ b/src/runtime/rpc/rpc_local_session.cc @@ -47,6 +47,10 @@ RPCSession::PackedFuncHandle LocalSession::GetFunction(const std::string& name) } void LocalSession::EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_return) { + if (rv.type_code() == kTVMObjectHandle && rv.IsObjectRef()) { + rv = std::string(rv.AsObjectRef()); + } + int rv_tcode = rv.type_code(); // return value encoding. From ad4705198a8f7d2f80a494d4a24a949afdb55634 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 11 Jul 2024 12:41:24 -0500 Subject: [PATCH 7/8] Replace python tvm.runtime.String with alias --- python/tvm/runtime/container.py | 31 +++++++++---------------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py index d23df9d41841..5de6c9b6d374 100644 --- a/python/tvm/runtime/container.py +++ b/python/tvm/runtime/container.py @@ -16,7 +16,7 @@ # under the License. """Runtime container structures.""" import tvm._ffi -from .object import Object, PyNativeObject +from .object import Object from .object_generic import ObjectTypes from . import _ffi_api @@ -112,29 +112,16 @@ def tuple_object(fields=None): return _ffi_api.Tuple(*fields) -@tvm._ffi.register_object("runtime.String") -class String(str, PyNativeObject): - """TVM runtime.String object, represented as a python str. +String = str +"""Backwards-compatibility alias - Parameters - ---------- - content : str - The content string used to construct the object. - """ - - __slots__ = ["__tvm_object__"] - - def __new__(cls, content): - """Construct from string content.""" - val = str.__new__(cls, content) - val.__init_tvm_object_by_constructor__(_ffi_api.String, content) - return val - - # pylint: disable=no-self-argument - def __from_tvm_object__(cls, obj: Object) -> str: - """Convert from runtime.String to native string""" +In previous implementations, when the C++ type `tvm::runtime::String` +was stored into a TVMRetValue, it used the type code kTVMObjectHandle. +It is now converted on storage into a TVMRetValue with type code +kTVMStr, removing the need for a separate `tvm.runtime.String` class. +This alias is maintained for backwards compatibility. - return _ffi_api.GetFFIString(obj) +""" @tvm._ffi.register_object("runtime.ShapeTuple") From 20573fad520ce697814782e521608d9ca0afc315 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 11 Jul 2024 15:03:26 -0500 Subject: [PATCH 8/8] Revert unnecessary change to runtime.ArrayGetItem --- src/runtime/container.cc | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/src/runtime/container.cc b/src/runtime/container.cc index 7d49e0bef427..7b5105a3fc94 100644 --- a/src/runtime/container.cc +++ b/src/runtime/container.cc @@ -49,18 +49,15 @@ TVM_REGISTER_GLOBAL("runtime.Array").set_body([](TVMArgs args, TVMRetValue* ret) *ret = Array(data); }); -// TVM_REGISTER_GLOBAL("runtime.ArrayGetItem").set_body([](TVMArgs args, TVMRetValue* ret) { -// int64_t i = args[1]; -// ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); -// Object* ptr = static_cast(args[0].value().v_handle); -// ICHECK(ptr->IsInstance()); -// auto* n = static_cast(ptr); -// ICHECK_LT(static_cast(i), n->size()) << "out of bound of array"; -// *ret = n->at(i); -// }); - -TVM_REGISTER_GLOBAL("runtime.ArrayGetItem") - .set_body_typed([](Array arr, size_t index) -> ObjectRef { return arr[index]; }); +TVM_REGISTER_GLOBAL("runtime.ArrayGetItem").set_body([](TVMArgs args, TVMRetValue* ret) { + int64_t i = args[1]; + ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + ICHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); + ICHECK_LT(static_cast(i), n->size()) << "out of bound of array"; + *ret = n->at(i); +}); TVM_REGISTER_GLOBAL("runtime.ArraySize").set_body([](TVMArgs args, TVMRetValue* ret) { ICHECK_EQ(args[0].type_code(), kTVMObjectHandle);