diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h index f7094b221221..28aef60b80cb 100644 --- a/include/tvm/relax/dataflow_pattern.h +++ b/include/tvm/relax/dataflow_pattern.h @@ -277,19 +277,8 @@ class PatternContextNode : public Object { */ class PatternContext : public ObjectRef { public: - TVM_DLL explicit PatternContext(ObjectPtr n) : ObjectRef(n) {} TVM_DLL explicit PatternContext(bool incremental = false); - const PatternContextNode* operator->() const { - ICHECK(get() != nullptr); - return static_cast(get()); - } - - PatternContextNode* operator->() { - ICHECK(get() != nullptr); - return static_cast(get_mutable()); - } - /*! * \brief Build an edge constraint between two patterns (producer and consumer). * @@ -333,6 +322,8 @@ class PatternContext : public ObjectRef { /*! \brief The RAII-like exit of a constraint context scope */ TVM_DLL void ExitWithScope() const; + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PatternContext, ObjectRef, PatternContextNode); + private: friend class With; }; diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 7266f8c4a50a..ef37ace19f5e 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -975,6 +976,7 @@ class TVMRetValue : public TVMPODValue_ { static TVMRetValue MoveFromCHost(TVMValue value, int type_code) { // Can move POD and everything under the object system. ICHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle); + TVMRetValue ret; ret.value_ = value; ret.type_code_ = type_code; @@ -2076,10 +2078,14 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { type_code_ == kTVMPackedFuncHandle) { // Casting to a base class that PackedFunc can sub-class return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); - } else { - TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); - return TObjectRef(ObjectPtr(nullptr)); + } else if constexpr (std::is_base_of::value) { + if (type_code_ == kTVMStr) { + return runtime::String(value_.v_str); + } } + + TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); + return TObjectRef(ObjectPtr(nullptr)); } template @@ -2102,6 +2108,13 @@ inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { ptr->IsInstance())) { return operator=(PackedFunc(std::move(other.data_))); } + if (std::is_base_of::value || + (std::is_base_of::value && + ptr->IsInstance())) { + const auto* string_obj = other.template as(); + return operator=(std::string(string_obj->data, string_obj->size)); + } + SwitchToObject(kTVMObjectHandle, std::move(other.data_)); } else { SwitchToPOD(kTVMNullptr); diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 120c1b71be72..b8e0ad6a1b0a 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -283,6 +283,15 @@ TVM_DLL const Op& tvm_struct_get(); */ TVM_DLL const Op& tvm_struct_set(); +/*! + * \brief TIR constructor for tvm::runtime::StringObj + * + * runtime::StringObj* tvm_string_obj(StringImm value) { + * return new StringObj(value); + * } + */ +TVM_DLL const Op& tvm_string_obj(); + /*! * \brief See pseudo code * Type lookup_param(String param_name) { diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py index 686b4a26c80c..5de6c9b6d374 100644 --- a/python/tvm/runtime/container.py +++ b/python/tvm/runtime/container.py @@ -16,7 +16,7 @@ # under the License. """Runtime container structures.""" import tvm._ffi -from .object import Object, PyNativeObject +from .object import Object from .object_generic import ObjectTypes from . import _ffi_api @@ -112,31 +112,16 @@ def tuple_object(fields=None): return _ffi_api.Tuple(*fields) -@tvm._ffi.register_object("runtime.String") -class String(str, PyNativeObject): - """TVM runtime.String object, represented as a python str. +String = str +"""Backwards-compatibility alias - Parameters - ---------- - content : str - The content string used to construct the object. - """ +In previous implementations, when the C++ type `tvm::runtime::String` +was stored into a TVMRetValue, it used the type code kTVMObjectHandle. +It is now converted on storage into a TVMRetValue with type code +kTVMStr, removing the need for a separate `tvm.runtime.String` class. +This alias is maintained for backwards compatibility. - __slots__ = ["__tvm_object__"] - - def __new__(cls, content): - """Construct from string content.""" - val = str.__new__(cls, content) - val.__init_tvm_object_by_constructor__(_ffi_api.String, content) - return val - - # pylint: disable=no-self-argument - def __from_tvm_object__(cls, obj): - """Construct from a given tvm object.""" - content = _ffi_api.GetFFIString(obj) - val = str.__new__(cls, content) - val.__tvm_object__ = obj - return val +""" @tvm._ffi.register_object("runtime.ShapeTuple") diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index ec74cbcdb62a..329604e191c9 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -24,7 +24,6 @@ from tvm._ffi import register_func as _register_func from tvm._ffi.runtime_ctypes import Device from tvm.runtime import Object, convert -from tvm.runtime.container import String from tvm.ir.container import Map, Array from . import _ffi_api @@ -128,10 +127,10 @@ def __init__(self, target, host=None): target = convert(target) if isinstance(host, (dict, str)): host = convert(host) - if target is None or not isinstance(target, (Map, String, Target)): + if target is None or not isinstance(target, (Map, str, Target)): raise ValueError("target has to be a string or dictionary.") if host is not None: - if not isinstance(host, (Map, String, Target)): + if not isinstance(host, (Map, str, Target)): raise ValueError("target host has to be a string or dictionary.") self.__init_handle_by_constructor__(_ffi_api.Target, Target(target), Target(host)) else: diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 81d6604259a3..901b653e1d45 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -1827,7 +1827,10 @@ def ret(val): The return expression """ - val = convert(val) + if isinstance(val, str): + val = StringImm(val) + else: + val = convert(val) return call_intrin(val.dtype, "tir.ret", val) diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index 02c34a1d133f..f21ba227cd32 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -40,7 +40,7 @@ pub struct Array { // the implementation. external! { #[name("runtime.ArrayGetItem")] - fn array_get_item(array: ObjectRef, index: isize) -> ObjectRef; + fn array_get_item(array: ObjectRef, index: isize) -> RetValue; #[name("runtime.ArraySize")] fn array_size(array: ObjectRef) -> i64; } @@ -96,8 +96,8 @@ impl Array { where T: TryFrom, { - let oref: ObjectRef = array_get_item(self.object.clone(), index)?; - oref.downcast() + let oref = array_get_item(self.object.clone(), index)?; + oref.try_into() } pub fn len(&self) -> i64 { diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 09d6068f1a88..b1907dbadccc 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -17,7 +17,7 @@ * under the License. */ -use std::convert::TryFrom; +use std::convert::{TryFrom, TryInto}; use std::ffi::CString; use std::fmt; use std::os::raw::c_char; @@ -31,6 +31,8 @@ use tvm_sys::ffi::{ use tvm_sys::{ArgValue, RetValue}; use crate::errors::Error; +use crate::IsObjectRef; +use crate::String as TVMString; type Deleter = unsafe extern "C" fn(object: *mut Object) -> (); @@ -320,6 +322,19 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { debug_assert!(optr.count() >= 1); optr.upcast::().downcast() } + RetValue::String(_) | RetValue::Str(_) => { + let string: String = ret_value.try_into().expect("Known to contain a string"); + + let string: TVMString = string.into(); + + let string = string + .into_ptr() + .expect("Known to contain a non-nullptr string"); + + debug_assert!(string.count() >= 1); + + string.upcast::().downcast() + } _ => Err(Error::downcast(format!("{:?}", ret_value), T::TYPE_KEY)), } } diff --git a/src/relay/analysis/annotated_region_set.h b/src/relay/analysis/annotated_region_set.h index 443bd5ec1da3..4c5ec1cfdf37 100644 --- a/src/relay/analysis/annotated_region_set.h +++ b/src/relay/analysis/annotated_region_set.h @@ -210,13 +210,6 @@ class AnnotatedRegionSet : public ObjectRef { data_ = std::move(n); } - /*! - * \brief Construct from an object pointer. - * - * \param n The object pointer. - */ - explicit AnnotatedRegionSet(ObjectPtr n) : ObjectRef(n) {} - /*! \return The begin iterator. */ iterator begin() { auto* n = operator->(); @@ -242,13 +235,6 @@ class AnnotatedRegionSet : public ObjectRef { return n->end(); } - /*! \return mutable pointers to the node. */ - AnnotatedRegionSetNode* operator->() const { - auto* ptr = get_mutable(); - ICHECK(ptr != nullptr); - return static_cast(ptr); - } - /*! \return The region an expression belongs to. */ AnnotatedRegion operator[](const Expr& expr) { const auto* n = operator->(); @@ -268,6 +254,9 @@ class AnnotatedRegionSet : public ObjectRef { static AnnotatedRegionSet Create(const Expr& expr, const Op& begin, const Op& end, const std::string& func_name = "default"); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AnnotatedRegionSet, ObjectRef, + AnnotatedRegionSetNode); + private: /*! \brief Helper class to construct a RegionSet from an expr.*/ class Creator; diff --git a/src/relay/analysis/call_graph.h b/src/relay/analysis/call_graph.h index 091891acd414..54ed00868360 100644 --- a/src/relay/analysis/call_graph.h +++ b/src/relay/analysis/call_graph.h @@ -207,12 +207,6 @@ class CallGraph : public ObjectRef { */ explicit CallGraph(IRModule module); - /*! - * \brief Construct from an object pointer. - * \param n The object pointer. - */ - explicit CallGraph(ObjectPtr n) : ObjectRef(n) {} - /*! \return The begin iterator. */ iterator begin() { auto* n = operator->(); @@ -287,12 +281,7 @@ class CallGraph : public ObjectRef { return (*n)[gvar_name]; } - /*! \return mutable pointers to the node. */ - CallGraphNode* operator->() const { - auto* ptr = get_mutable(); - ICHECK(ptr != nullptr); - return static_cast(ptr); - } + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CallGraph, ObjectRef, CallGraphNode); private: /*! \brief Overload the << operator to print a call graph. */ diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index d08dadb02bb9..bdd34ea4da49 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -33,6 +33,14 @@ class Object; /*! \brief The current RPC procotol version. */ constexpr const char* kRPCProtocolVer = "0.8.0"; +/*! + * \brief type index of kRuntimeString + * \note this needs to be kept consistent with runtime/object.h + * but we explicitly declare it here because minrpc needs to be minimum dep + * only c C API + */ +constexpr const int kRuntimeString = 3; + /*! * \brief type index of kRuntimeRPCObjectRefTypeIndex * \note this needs to be kept consistent with runtime/object.h diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc index 92691ee6fd28..25059a74797f 100644 --- a/src/runtime/rpc/rpc_local_session.cc +++ b/src/runtime/rpc/rpc_local_session.cc @@ -47,6 +47,10 @@ RPCSession::PackedFuncHandle LocalSession::GetFunction(const std::string& name) } void LocalSession::EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_return) { + if (rv.type_code() == kTVMObjectHandle && rv.IsObjectRef()) { + rv = std::string(rv.AsObjectRef()); + } + int rv_tcode = rv.type_code(); // return value encoding. diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 481ba39cc7b1..f21399a111be 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -399,6 +399,93 @@ llvm::Value* CodeGenCPU::CreateCallExtern(Type ret_type, String global_symbol, return builder_->CreateCall(ftype, callee, arg_values); } +llvm::Value* CodeGenCPU::CreateStringObj(StringImm ir_string) { + auto t_deleter = llvm::FunctionType::get(t_void_, {t_void_p_}, false); + + if (!t_tvm_base_object_) { + t_tvm_base_object_ = llvm::StructType::create( + { + t_int32_ /* type_index_ */, + t_int32_ /* ref_counter_ */, + t_deleter->getPointerTo() /* deleter_ */, + }, + "tvm::runtime::Object", true); + } + if (!t_tvm_string_obj_) { + t_tvm_string_obj_ = llvm::StructType::create( + { + t_tvm_base_object_, + t_char_->getPointerTo() /* data */, + t_int64_ /* size */, + }, + "tvm::runtime::StringObj"); + } + if (!f_string_obj_deleter_) { + auto prev_insert_point = builder_->GetInsertBlock(); + + f_string_obj_deleter_ = llvm::Function::Create(t_deleter, llvm::Function::PrivateLinkage, + "string_obj_deleter", module_.get()); + + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + + auto* arg = f_string_obj_deleter_->getArg(0); + arg->setName("base_obj_ptr"); + + auto* entry = llvm::BasicBlock::Create(*ctx, "deleter_entry", f_string_obj_deleter_); + + builder_->SetInsertPoint(entry); + + // Currently, the only pointers stored in a generated StringObj + // are pointers to static allocations made using + // llvm::ConstantDataArray::getString. These static allocations + // should not be deleted after use, so the only allocation that + // needs to be deleted is the `CreateMalloc` containing the + // `StringObj` itself. + llvm::Instruction* free_inst = llvm::CallInst::CreateFree(arg, builder_->GetInsertBlock()); + + builder_->Insert(free_inst); + builder_->CreateRetVoid(); + + builder_->SetInsertPoint(prev_insert_point); + } + + llvm::Value* alloc_size = + llvm::ConstantInt::get(t_int64_, data_layout_->getTypeAllocSize(t_tvm_string_obj_)); + llvm::Instruction* malloc_inst = llvm::CallInst::CreateMalloc( + builder_->GetInsertBlock(), t_int64_, t_tvm_string_obj_, alloc_size, nullptr, nullptr, ""); + + builder_->Insert(malloc_inst); + + llvm::Value* out = builder_->CreatePointerCast(malloc_inst, t_tvm_string_obj_->getPointerTo()); + + llvm::Value* string_obj = builder_->CreateInBoundsGEP(t_tvm_string_obj_, out, ConstInt32(0)); + + builder_->CreateStore(ConstInt32(TypeIndex::kRuntimeString), + builder_->CreateInBoundsGEP(t_tvm_string_obj_, string_obj, + {ConstInt32(0), ConstInt32(0), ConstInt32(0)}, + "output->type_index_")); + + builder_->CreateStore(ConstInt32(1), + builder_->CreateInBoundsGEP(t_tvm_string_obj_, string_obj, + {ConstInt32(0), ConstInt32(0), ConstInt32(1)}, + "output->ref_counter_")); + + builder_->CreateStore(f_string_obj_deleter_, + builder_->CreateInBoundsGEP(t_tvm_string_obj_, string_obj, + {ConstInt32(0), ConstInt32(0), ConstInt32(2)}, + "output->deleter_")); + builder_->CreateStore( + GetConstString(ir_string->value), + builder_->CreateInBoundsGEP(t_tvm_string_obj_, string_obj, {ConstInt32(0), ConstInt32(1)}, + "output->data")); + builder_->CreateStore( + llvm::ConstantInt::getSigned(t_int64_, ir_string->value.size()), + builder_->CreateInBoundsGEP(t_tvm_string_obj_, string_obj, {ConstInt32(0), ConstInt32(2)}, + "output->size")); + + return builder_->CreatePointerCast(out, t_void_p_); +} + llvm::GlobalVariable* CodeGenCPU::InitContextPtr(llvm::Type* p_type, std::string name) { llvm::GlobalVariable* gv = new llvm::GlobalVariable( *module_, p_type, false, llvm::GlobalValue::LinkOnceAnyLinkage, nullptr, name); @@ -1381,6 +1468,10 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { } builder_->CreateStore(value, ref.addr); return ConstInt32(0); + } else if (op->op.same_as(builtin::tvm_string_obj())) { + ICHECK_EQ(op->args.size(), 1U); + ICHECK(op->args[0].dtype() == DataType::Handle()); + return CreateStringObj(Downcast(op->args[0])); } else if (op->op.same_as(builtin::tvm_stack_alloca())) { ICHECK_EQ(op->args.size(), 2U); const std::string& type = op->args[0].as()->value; diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index 91fe1bc18631..cb4c2531e771 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -76,6 +76,7 @@ class CodeGenCPU : public CodeGenLLVM { llvm::Value* CreateIntrinsic(const CallNode* op) override; llvm::Value* CreateCallExtern(Type ret_type, String global_symbol, const Array& args, bool skip_first_arg) override; + llvm::Value* CreateStringObj(StringImm ir_string); /*! * \brief A CPU-specific function to create the FuncRegistry. @@ -102,6 +103,10 @@ class CodeGenCPU : public CodeGenLLVM { llvm::StructType* t_tvm_value_{nullptr}; llvm::StructType* t_tvm_parallel_group_env_{nullptr}; + llvm::StructType* t_tvm_base_object_{nullptr}; + llvm::StructType* t_tvm_string_obj_{nullptr}; + llvm::Function* f_string_obj_deleter_{nullptr}; + llvm::FunctionType* ftype_tvm_backend_packed_c_func_{nullptr}; llvm::StructType* t_tvm_crt_func_registry_{nullptr}; llvm::StructType* t_tvm_crt_module_{nullptr}; diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 0404fd28230e..820b4f65a845 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -173,6 +173,10 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_struct_set) .set_num_inputs(4) .set_attr("TCallEffectKind", Integer(CallEffectKind::kUpdateState)); +TIR_DEFINE_BUILTIN_FUNC(tvm_string_obj) + .set_num_inputs(1) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_BUILTIN_FUNC(lookup_param) .set_num_inputs(4) .set_attr("TCallEffectKind", Integer(CallEffectKind::kUpdateState)); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index d327cdfa8393..4df063dc5d38 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -90,6 +90,9 @@ class ReturnRewriter : public StmtMutator { } else if (dtype.is_void()) { info.tcode = kTVMNullptr; info.expr = val; + } else if (val.as()) { + info.tcode = kTVMObjectHandle; + info.expr = tir::Call(DataType::Handle(), builtin::tvm_string_obj(), {val}); } else { LOG(FATAL) << "data type " << dtype << " not supported yet"; } diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index f50d63878e4f..31d762209b2a 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -18,10 +18,13 @@ import ctypes import json import math -import numpy as np -import pytest +import os import re import sys +import tempfile + +import numpy as np +import pytest import tvm import tvm.testing @@ -1138,5 +1141,23 @@ def func(): tvm.build(func) +@pytest.mark.parametrize("save_and_reload", [True, False]) +def test_return_string_from_tir(save_and_reload): + @T.prim_func + def func(): + return "hello!" + + built = tvm.build(func, target="llvm") + + if save_and_reload: + with tempfile.TemporaryDirectory(prefix="tvm_testing_") as temp_dir: + temp_file = os.path.join(temp_dir, "libbuilt.so") + built.export_library(temp_file) + built = tvm.runtime.load_module(temp_file) + + out = built() + assert out == "hello!" + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/runtime/test_runtime_rpc.py b/tests/python/runtime/test_runtime_rpc.py index fbdc33928b6e..d9d9dbc7924a 100644 --- a/tests/python/runtime/test_runtime_rpc.py +++ b/tests/python/runtime/test_runtime_rpc.py @@ -734,3 +734,25 @@ def func_with_arg(unused: T.int64) -> T.int64: res = remote_mod["func_without_arg"]() assert res == 42 + + +def test_return_string_over_rpc(): + @T.prim_func + def func(unused: T.int64) -> T.handle: + return T.StringImm("hello!") + + built = tvm.build(func, target="llvm") + + server = tvm.rpc.Server(key="x1") + client = tvm.rpc.connect("127.0.0.1", server.port, key="x1") + + libname = "libbuilt.so" + with tempfile.TemporaryDirectory(prefix="tvm_rpc_testing_") as temp_dir: + local_path = os.path.join(temp_dir, libname) + built.export_library(local_path) + client.upload(local_path) + + remote_mod = client.load_module(libname) + + out = remote_mod(42) + assert out == "hello!" diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index f81a80de6d61..2fdfdfc65af5 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -4114,6 +4114,21 @@ def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")): return func +def tir_return_string_imm(): + """TIR StringImm must round-trip + + The conversion from Python str to TIR StringImm occurs at the + callee. + + """ + + @T.prim_func + def func(): + return T.StringImm("hello") + + return func + + ir_generator = tvm.testing.parameter( launch_env_thread, opt_gemm_normalize, @@ -4202,6 +4217,7 @@ def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")): *relax_match_cast_struct_info_proxy(), relax_symbolic_size_var, relax_float_symbolic_var, + tir_return_string_imm, ) relax_ir_generator = tvm.testing.parameter(