Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/runtime/c_runtime_api.h>
#include <dmlc/logging.h>
#include <type_traits>
#include <string>

namespace tvm {
namespace runtime {
Expand Down Expand Up @@ -263,6 +264,141 @@ inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) {
inline bool TypeEqual(DLDataType lhs, DLDataType rhs) {
return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
}

/*!
* \brief Runtime utility for getting custom type name from code
* \param type_code Custom type code
* \return Custom type name
*/
TVM_DLL std::string GetCustomTypeName(uint8_t type_code);

/*!
* \brief Runtime utility for checking whether custom type is registered
* \param type_code Custom type code
* \return Bool representing whether type is registered
*/
TVM_DLL bool GetCustomTypeRegistered(uint8_t type_code);

/*!
* \brief Runtime utility for parsing string of the form "custom[<typename>]"
* \param s String to parse
* \param scan pointer to parsing pointer, which is scanning across s
* \return type code of custom type parsed
*/
TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan);

/*!
* \brief Convert type code to its name
* \param type_code The type code .
* \return The name of type code.
*/
inline const char* TypeCode2Str(int type_code);

/*!
* \brief convert a string to TVM type.
* \param s The string to be converted.
* \return The corresponding tvm type.
*/
inline DLDataType String2DLDataType(std::string s);

/*!
* \brief convert a TVM type to string.
* \param t The type to be converted.
* \return The corresponding tvm type in string.
*/
inline std::string DLDataType2String(DLDataType t);

// implementation details
inline const char* TypeCode2Str(int type_code) {
switch (type_code) {
case kDLInt: return "int";
case kDLUInt: return "uint";
case kDLFloat: return "float";
case kTVMStr: return "str";
case kTVMBytes: return "bytes";
case kTVMOpaqueHandle: return "handle";
case kTVMNullptr: return "NULL";
case kTVMDLTensorHandle: return "ArrayHandle";
case kTVMDataType: return "DLDataType";
case kTVMContext: return "TVMContext";
case kTVMPackedFuncHandle: return "FunctionHandle";
case kTVMModuleHandle: return "ModuleHandle";
case kTVMNDArrayHandle: return "NDArrayContainer";
case kTVMObjectHandle: return "Object";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
}

inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
os << "bool"; return os;
}
if (t.code < kTVMCustomBegin) {
os << TypeCode2Str(t.code);
} else {
os << "custom[" << GetCustomTypeName(t.code) << "]";
}
if (t.code == kTVMOpaqueHandle) return os;
os << static_cast<int>(t.bits);
if (t.lanes != 1) {
os << 'x' << static_cast<int>(t.lanes);
}
return os;
}

inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*)
return os << dtype.operator DLDataType();
}

inline std::string DLDataType2String(DLDataType t) {
if (t.bits == 0) return "";
std::ostringstream os;
os << t;
return os.str();
}

inline DLDataType String2DLDataType(std::string s) {
DLDataType t;
// handle None type
if (s.length() == 0) {
t.bits = 0; t.lanes = 0; t.code = kTVMOpaqueHandle;
return t;
}
t.bits = 32; t.lanes = 1;
const char* scan;
if (s.substr(0, 3) == "int") {
t.code = kDLInt; scan = s.c_str() + 3;
} else if (s.substr(0, 4) == "uint") {
t.code = kDLUInt; scan = s.c_str() + 4;
} else if (s.substr(0, 5) == "float") {
t.code = kDLFloat; scan = s.c_str() + 5;
} else if (s.substr(0, 6) == "handle") {
t.code = kTVMOpaqueHandle;
t.bits = 64; // handle uses 64 bit by default.
scan = s.c_str() + 6;
} else if (s == "bool") {
t.code = kDLUInt;
t.bits = 1;
t.lanes = 1;
return t;
} else if (s.substr(0, 6) == "custom") {
t.code = ParseCustomDatatype(s, &scan);
} else {
scan = s.c_str();
LOG(FATAL) << "unknown type " << s;
}
char* xdelim; // emulate sscanf("%ux%u", bits, lanes)
uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
if (bits != 0) t.bits = bits;
char* endpt = xdelim;
if (*xdelim == 'x') {
t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, &endpt, 10));
}
CHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
return t;
}

} // namespace runtime

using DataType = runtime::DataType;
Expand Down
142 changes: 9 additions & 133 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/container.h>
#include <functional>
#include <tuple>
#include <vector>
Expand All @@ -52,28 +53,6 @@ class PrimExpr;

namespace runtime {

/*!
* \brief Runtime utility for getting custom type name from code
* \param type_code Custom type code
* \return Custom type name
*/
TVM_DLL std::string GetCustomTypeName(uint8_t type_code);

/*!
* \brief Runtime utility for checking whether custom type is registered
* \param type_code Custom type code
* \return Bool representing whether type is registered
*/
TVM_DLL bool GetCustomTypeRegistered(uint8_t type_code);

/*!
* \brief Runtime utility for parsing string of the form "custom[<typename>]"
* \param s String to parse
* \param scan pointer to parsing pointer, which is scanning across s
* \return type code of custom type parsed
*/
TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan);

// forward declarations
class TVMArgs;
class TVMArgValue;
Expand Down Expand Up @@ -359,27 +338,6 @@ class TVMArgs {
inline TVMArgValue operator[](int i) const;
};

/*!
* \brief Convert type code to its name
* \param type_code The type code .
* \return The name of type code.
*/
inline const char* TypeCode2Str(int type_code);

/*!
* \brief convert a string to TVM type.
* \param s The string to be converted.
* \return The corresponding tvm type.
*/
inline DLDataType String2DLDataType(std::string s);

/*!
* \brief convert a TVM type to string.
* \param t The type to be converted.
* \return The corresponding tvm type in string.
*/
inline std::string DLDataType2String(DLDataType t);

// macro to check type code.
#define TVM_CHECK_TYPE_CODE(CODE, T) \
CHECK_EQ(CODE, T) << " expected " \
Expand Down Expand Up @@ -554,6 +512,10 @@ class TVMArgValue : public TVMPODValue_ {
return std::string(value_.v_str);
}
}
operator tvm::runtime::String() const {
// directly use the std::string constructor for now.
return tvm::runtime::String(operator std::string());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tqchen It happened to me that Line511 above failed for the check because the type_code_ for String is an object. Should we remove this and pass String objectref directly? Or do we need to handle String through FFI?

Copy link
Member Author

@tqchen tqchen Apr 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, i see, good catch, we will need to add a patch, by checking if the result is kStr and run this, alternatively, use AsObjectRef

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so let's just return AsObjectRef<tvm::runtime::String>() for now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}
operator DLDataType() const {
if (type_code_ == kTVMStr) {
return String2DLDataType(operator std::string());
Expand Down Expand Up @@ -642,6 +604,10 @@ class TVMRetValue : public TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kTVMStr);
return *ptr<std::string>();
}
operator tvm::runtime::String() const {
// directly use the std::string constructor for now.
return tvm::runtime::String(operator std::string());
}
operator DLDataType() const {
if (type_code_ == kTVMStr) {
return String2DLDataType(operator std::string());
Expand Down Expand Up @@ -994,96 +960,6 @@ class TVMRetValue : public TVMPODValue_ {
} \
}

// implementation details
inline const char* TypeCode2Str(int type_code) {
switch (type_code) {
case kDLInt: return "int";
case kDLUInt: return "uint";
case kDLFloat: return "float";
case kTVMStr: return "str";
case kTVMBytes: return "bytes";
case kTVMOpaqueHandle: return "handle";
case kTVMNullptr: return "NULL";
case kTVMDLTensorHandle: return "ArrayHandle";
case kTVMDataType: return "DLDataType";
case kTVMContext: return "TVMContext";
case kTVMPackedFuncHandle: return "FunctionHandle";
case kTVMModuleHandle: return "ModuleHandle";
case kTVMNDArrayHandle: return "NDArrayContainer";
case kTVMObjectHandle: return "Object";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
}

inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
os << "bool"; return os;
}
if (t.code < kTVMCustomBegin) {
os << TypeCode2Str(t.code);
} else {
os << "custom[" << GetCustomTypeName(t.code) << "]";
}
if (t.code == kTVMOpaqueHandle) return os;
os << static_cast<int>(t.bits);
if (t.lanes != 1) {
os << 'x' << static_cast<int>(t.lanes);
}
return os;
}

inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*)
return os << dtype.operator DLDataType();
}

inline std::string DLDataType2String(DLDataType t) {
if (t.bits == 0) return "";
std::ostringstream os;
os << t;
return os.str();
}

inline DLDataType String2DLDataType(std::string s) {
DLDataType t;
// handle None type
if (s.length() == 0) {
t.bits = 0; t.lanes = 0; t.code = kTVMOpaqueHandle;
return t;
}
t.bits = 32; t.lanes = 1;
const char* scan;
if (s.substr(0, 3) == "int") {
t.code = kDLInt; scan = s.c_str() + 3;
} else if (s.substr(0, 4) == "uint") {
t.code = kDLUInt; scan = s.c_str() + 4;
} else if (s.substr(0, 5) == "float") {
t.code = kDLFloat; scan = s.c_str() + 5;
} else if (s.substr(0, 6) == "handle") {
t.code = kTVMOpaqueHandle;
t.bits = 64; // handle uses 64 bit by default.
scan = s.c_str() + 6;
} else if (s == "bool") {
t.code = kDLUInt;
t.bits = 1;
t.lanes = 1;
return t;
} else if (s.substr(0, 6) == "custom") {
t.code = ParseCustomDatatype(s, &scan);
} else {
scan = s.c_str();
LOG(FATAL) << "unknown type " << s;
}
char* xdelim; // emulate sscanf("%ux%u", bits, lanes)
uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
if (bits != 0) t.bits = bits;
char* endpt = xdelim;
if (*xdelim == 'x') {
t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, &endpt, 10));
}
CHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
return t;
}

inline TVMArgValue TVMArgs::operator[](int i) const {
CHECK_LT(i, num_args)
Expand Down
2 changes: 2 additions & 0 deletions tests/cpp/packed_func_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ TEST(PackedFunc, str) {
CHECK(args.num_args == 1);
std::string x = args[0];
CHECK(x == "hello");
String y = args[0];
CHECK(y == "hello");
*rv = x;
})("hello");
}
Expand Down