Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 2 additions & 11 deletions include/tvm/relax/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,19 +277,8 @@ class PatternContextNode : public Object {
*/
class PatternContext : public ObjectRef {
public:
TVM_DLL explicit PatternContext(ObjectPtr<Object> n) : ObjectRef(n) {}
TVM_DLL explicit PatternContext(bool incremental = false);

const PatternContextNode* operator->() const {
ICHECK(get() != nullptr);
return static_cast<const PatternContextNode*>(get());
}

PatternContextNode* operator->() {
ICHECK(get() != nullptr);
return static_cast<PatternContextNode*>(get_mutable());
}

/*!
* \brief Build an edge constraint between two patterns (producer and consumer).
*
Expand Down Expand Up @@ -333,6 +322,8 @@ class PatternContext : public ObjectRef {
/*! \brief The RAII-like exit of a constraint context scope */
TVM_DLL void ExitWithScope() const;

TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PatternContext, ObjectRef, PatternContextNode);

private:
friend class With<PatternContext>;
};
Expand Down
19 changes: 16 additions & 3 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/map.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/container/variant.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/logging.h>
Expand Down Expand Up @@ -975,6 +976,7 @@ class TVMRetValue : public TVMPODValue_ {
static TVMRetValue MoveFromCHost(TVMValue value, int type_code) {
// Can move POD and everything under the object system.
ICHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle);

TVMRetValue ret;
ret.value_ = value;
ret.type_code_ = type_code;
Expand Down Expand Up @@ -2076,10 +2078,14 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
type_code_ == kTVMPackedFuncHandle) {
// Casting to a base class that PackedFunc can sub-class
return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
} else {
TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
return TObjectRef(ObjectPtr<Object>(nullptr));
} else if constexpr (std::is_base_of<ContainerType, StringObj>::value) {
if (type_code_ == kTVMStr) {
return runtime::String(value_.v_str);
}
}

TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
return TObjectRef(ObjectPtr<Object>(nullptr));
}

template <typename TObjectRef, typename>
Expand All @@ -2102,6 +2108,13 @@ inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) {
ptr->IsInstance<PackedFunc::ContainerType>())) {
return operator=(PackedFunc(std::move(other.data_)));
}
if (std::is_base_of<runtime::String::ContainerType, ContainerType>::value ||
(std::is_base_of<ContainerType, runtime::String::ContainerType>::value &&
ptr->IsInstance<runtime::String::ContainerType>())) {
const auto* string_obj = other.template as<runtime::String::ContainerType>();
return operator=(std::string(string_obj->data, string_obj->size));
}

SwitchToObject(kTVMObjectHandle, std::move(other.data_));
} else {
SwitchToPOD(kTVMNullptr);
Expand Down
9 changes: 9 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,15 @@ TVM_DLL const Op& tvm_struct_get();
*/
TVM_DLL const Op& tvm_struct_set();

/*!
* \brief TIR constructor for tvm::runtime::StringObj
*
* runtime::StringObj* tvm_string_obj(StringImm value) {
* return new StringObj(value);
* }
*/
TVM_DLL const Op& tvm_string_obj();

/*!
* \brief See pseudo code
* Type lookup_param(String param_name) {
Expand Down
33 changes: 9 additions & 24 deletions python/tvm/runtime/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
"""Runtime container structures."""
import tvm._ffi
from .object import Object, PyNativeObject
from .object import Object
from .object_generic import ObjectTypes
from . import _ffi_api

Expand Down Expand Up @@ -112,31 +112,16 @@ def tuple_object(fields=None):
return _ffi_api.Tuple(*fields)


@tvm._ffi.register_object("runtime.String")
class String(str, PyNativeObject):
"""TVM runtime.String object, represented as a python str.
String = str
"""Backwards-compatibility alias

Parameters
----------
content : str
The content string used to construct the object.
"""
In previous implementations, when the C++ type `tvm::runtime::String`
was stored into a TVMRetValue, it used the type code kTVMObjectHandle.
It is now converted on storage into a TVMRetValue with type code
kTVMStr, removing the need for a separate `tvm.runtime.String` class.
This alias is maintained for backwards compatibility.

__slots__ = ["__tvm_object__"]

def __new__(cls, content):
"""Construct from string content."""
val = str.__new__(cls, content)
val.__init_tvm_object_by_constructor__(_ffi_api.String, content)
return val

# pylint: disable=no-self-argument
def __from_tvm_object__(cls, obj):
"""Construct from a given tvm object."""
content = _ffi_api.GetFFIString(obj)
val = str.__new__(cls, content)
val.__tvm_object__ = obj
return val
"""


@tvm._ffi.register_object("runtime.ShapeTuple")
Expand Down
5 changes: 2 additions & 3 deletions python/tvm/target/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from tvm._ffi import register_func as _register_func
from tvm._ffi.runtime_ctypes import Device
from tvm.runtime import Object, convert
from tvm.runtime.container import String
from tvm.ir.container import Map, Array

from . import _ffi_api
Expand Down Expand Up @@ -128,10 +127,10 @@ def __init__(self, target, host=None):
target = convert(target)
if isinstance(host, (dict, str)):
host = convert(host)
if target is None or not isinstance(target, (Map, String, Target)):
if target is None or not isinstance(target, (Map, str, Target)):
raise ValueError("target has to be a string or dictionary.")
if host is not None:
if not isinstance(host, (Map, String, Target)):
if not isinstance(host, (Map, str, Target)):
raise ValueError("target host has to be a string or dictionary.")
self.__init_handle_by_constructor__(_ffi_api.Target, Target(target), Target(host))
else:
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1827,7 +1827,10 @@ def ret(val):
The return expression
"""

val = convert(val)
if isinstance(val, str):
val = StringImm(val)
else:
val = convert(val)
return call_intrin(val.dtype, "tir.ret", val)


Expand Down
6 changes: 3 additions & 3 deletions rust/tvm-rt/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub struct Array<T: IsObjectRef> {
// the implementation.
external! {
#[name("runtime.ArrayGetItem")]
fn array_get_item(array: ObjectRef, index: isize) -> ObjectRef;
fn array_get_item(array: ObjectRef, index: isize) -> RetValue;
#[name("runtime.ArraySize")]
fn array_size(array: ObjectRef) -> i64;
}
Expand Down Expand Up @@ -96,8 +96,8 @@ impl<T: IsObjectRef> Array<T> {
where
T: TryFrom<RetValue, Error = Error>,
{
let oref: ObjectRef = array_get_item(self.object.clone(), index)?;
oref.downcast()
let oref = array_get_item(self.object.clone(), index)?;
oref.try_into()
}

pub fn len(&self) -> i64 {
Expand Down
17 changes: 16 additions & 1 deletion rust/tvm-rt/src/object/object_ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
* under the License.
*/

use std::convert::TryFrom;
use std::convert::{TryFrom, TryInto};
use std::ffi::CString;
use std::fmt;
use std::os::raw::c_char;
Expand All @@ -31,6 +31,8 @@ use tvm_sys::ffi::{
use tvm_sys::{ArgValue, RetValue};

use crate::errors::Error;
use crate::IsObjectRef;
use crate::String as TVMString;

type Deleter = unsafe extern "C" fn(object: *mut Object) -> ();

Expand Down Expand Up @@ -320,6 +322,19 @@ impl<'a, T: IsObject> TryFrom<RetValue> for ObjectPtr<T> {
debug_assert!(optr.count() >= 1);
optr.upcast::<Object>().downcast()
}
RetValue::String(_) | RetValue::Str(_) => {
let string: String = ret_value.try_into().expect("Known to contain a string");

let string: TVMString = string.into();

let string = string
.into_ptr()
.expect("Known to contain a non-nullptr string");

debug_assert!(string.count() >= 1);

string.upcast::<Object>().downcast()
}
_ => Err(Error::downcast(format!("{:?}", ret_value), T::TYPE_KEY)),
}
}
Expand Down
17 changes: 3 additions & 14 deletions src/relay/analysis/annotated_region_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,6 @@ class AnnotatedRegionSet : public ObjectRef {
data_ = std::move(n);
}

/*!
* \brief Construct from an object pointer.
*
* \param n The object pointer.
*/
explicit AnnotatedRegionSet(ObjectPtr<Object> n) : ObjectRef(n) {}

/*! \return The begin iterator. */
iterator begin() {
auto* n = operator->();
Expand All @@ -242,13 +235,6 @@ class AnnotatedRegionSet : public ObjectRef {
return n->end();
}

/*! \return mutable pointers to the node. */
AnnotatedRegionSetNode* operator->() const {
auto* ptr = get_mutable();
ICHECK(ptr != nullptr);
return static_cast<AnnotatedRegionSetNode*>(ptr);
}

/*! \return The region an expression belongs to. */
AnnotatedRegion operator[](const Expr& expr) {
const auto* n = operator->();
Expand All @@ -268,6 +254,9 @@ class AnnotatedRegionSet : public ObjectRef {
static AnnotatedRegionSet Create(const Expr& expr, const Op& begin, const Op& end,
const std::string& func_name = "default");

TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AnnotatedRegionSet, ObjectRef,
AnnotatedRegionSetNode);

private:
/*! \brief Helper class to construct a RegionSet from an expr.*/
class Creator;
Expand Down
13 changes: 1 addition & 12 deletions src/relay/analysis/call_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,12 +207,6 @@ class CallGraph : public ObjectRef {
*/
explicit CallGraph(IRModule module);

/*!
* \brief Construct from an object pointer.
* \param n The object pointer.
*/
explicit CallGraph(ObjectPtr<Object> n) : ObjectRef(n) {}

/*! \return The begin iterator. */
iterator begin() {
auto* n = operator->();
Expand Down Expand Up @@ -287,12 +281,7 @@ class CallGraph : public ObjectRef {
return (*n)[gvar_name];
}

/*! \return mutable pointers to the node. */
CallGraphNode* operator->() const {
auto* ptr = get_mutable();
ICHECK(ptr != nullptr);
return static_cast<CallGraphNode*>(ptr);
}
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CallGraph, ObjectRef, CallGraphNode);

private:
/*! \brief Overload the << operator to print a call graph. */
Expand Down
8 changes: 8 additions & 0 deletions src/runtime/minrpc/rpc_reference.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ class Object;
/*! \brief The current RPC procotol version. */
constexpr const char* kRPCProtocolVer = "0.8.0";

/*!
* \brief type index of kRuntimeString
* \note this needs to be kept consistent with runtime/object.h
* but we explicitly declare it here because minrpc needs to be minimum dep
* only c C API
*/
constexpr const int kRuntimeString = 3;

/*!
* \brief type index of kRuntimeRPCObjectRefTypeIndex
* \note this needs to be kept consistent with runtime/object.h
Expand Down
4 changes: 4 additions & 0 deletions src/runtime/rpc/rpc_local_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ RPCSession::PackedFuncHandle LocalSession::GetFunction(const std::string& name)
}

void LocalSession::EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_return) {
if (rv.type_code() == kTVMObjectHandle && rv.IsObjectRef<runtime::String>()) {
rv = std::string(rv.AsObjectRef<runtime::String>());
}

int rv_tcode = rv.type_code();

// return value encoding.
Expand Down
Loading