diff --git a/rust/.gitignore b/rust/.gitignore new file mode 100644 index 000000000000..0cc660650780 --- /dev/null +++ b/rust/.gitignore @@ -0,0 +1,4 @@ +target/ +*.rs.bk +Cargo.lock +c_runtime_api.rs diff --git a/rust/common/.gitignore b/rust/common/.gitignore deleted file mode 100644 index 84c2ae99001c..000000000000 --- a/rust/common/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -target -**/*.rs.bk -Cargo.lock -/tvm-sys/src/bindgen.rs diff --git a/rust/common/Cargo.toml b/rust/common/Cargo.toml index bcba5ad62fc9..5d21ee509b02 100644 --- a/rust/common/Cargo.toml +++ b/rust/common/Cargo.toml @@ -5,9 +5,11 @@ authors = ["TVM Contributors"] license = "Apache-2.0" [features] -runtime = [] -frontend = ["tvm-sys"] +bindings = [] [dependencies] -error-chain = { version = "0.12.0", default-features = false } -tvm-sys = { version = "0.1.0", path = "tvm-sys", optional = true } +failure = "0.1.5" +ndarray = "0.12.1" + +[build-dependencies] +bindgen = "0.37.4" diff --git a/rust/common/build.rs b/rust/common/build.rs new file mode 100644 index 000000000000..f07e71f0f2bb --- /dev/null +++ b/rust/common/build.rs @@ -0,0 +1,31 @@ +extern crate bindgen; + +use std::path::PathBuf; + +fn main() { + if cfg!(feature = "bindings") { + println!("cargo:rerun-if-env-changed=TVM_HOME"); + println!("cargo:rustc-link-lib=dylib=tvm_runtime"); + println!("cargo:rustc-link-search={}/build", env!("TVM_HOME")); + } + + // @see rust-bindgen#550 for `blacklist_type` + bindgen::Builder::default() + .header(format!( + "{}/include/tvm/runtime/c_runtime_api.h", + env!("TVM_HOME") + )) + .header(format!( + "{}/include/tvm/runtime/c_backend_api.h", + env!("TVM_HOME") + )) + .clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME"))) + .blacklist_type("max_align_t") + .layout_tests(false) + .derive_partialeq(true) + .derive_eq(true) + .generate() + .expect("unable to generate bindings") + .write_to_file(PathBuf::from("src/c_runtime_api.rs")) + .expect("can not write the bindings!"); +} diff --git a/rust/common/src/array.rs b/rust/common/src/array.rs new file mode 100644 index 000000000000..e7b75850677d --- /dev/null +++ b/rust/common/src/array.rs @@ -0,0 +1,128 @@ +use std::{ + any::TypeId, + mem, + os::raw::{c_int, c_void}, +}; + +use crate::ffi::{ + DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, + DLDeviceType_kDLCPU, DLTensor, +}; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct DataType { + pub code: usize, + pub bits: usize, + pub lanes: usize, +} + +impl DataType { + /// Returns the number of bytes occupied by an element of this `DataType`. + pub fn itemsize(&self) -> usize { + (self.bits * self.lanes) >> 3 + } + + /// Returns whether this `DataType` represents primitive type `T`. + pub fn is_type(&self) -> bool { + if self.lanes != 1 { + return false; + } + let typ = TypeId::of::(); + (typ == TypeId::of::() && self.code == 0 && self.bits == 32) + || (typ == TypeId::of::() && self.code == 0 && self.bits == 64) + || (typ == TypeId::of::() && self.code == 1 && self.bits == 32) + || (typ == TypeId::of::() && self.code == 1 && self.bits == 64) + || (typ == TypeId::of::() && self.code == 2 && self.bits == 32) + || (typ == TypeId::of::() && self.code == 2 && self.bits == 64) + } + + pub fn code(&self) -> usize { + self.code + } + + pub fn bits(&self) -> usize { + self.bits + } + + pub fn lanes(&self) -> usize { + self.lanes + } +} + +impl<'a> From<&'a DataType> for DLDataType { + fn from(dtype: &'a DataType) -> Self { + Self { + code: dtype.code as u8, + bits: dtype.bits as u8, + lanes: dtype.lanes as u16, + } + } +} + +impl From for DataType { + fn from(dtype: DLDataType) -> Self { + Self { + code: dtype.code as usize, + bits: dtype.bits as usize, + lanes: dtype.lanes as usize, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct TVMContext { + pub device_type: usize, + pub device_id: usize, +} + +impl<'a> From<&'a TVMContext> for DLContext { + fn from(ctx: &'a TVMContext) -> Self { + Self { + device_type: ctx.device_type as u32, + device_id: ctx.device_id as i32, + } + } +} + +impl Default for TVMContext { + fn default() -> Self { + Self { + device_type: DLDeviceType_kDLCPU as usize, + device_id: 0, + } + } +} + +/// `From` conversions to `DLTensor` for `ndarray::Array`. +/// Takes a reference to the `ndarray` since `DLTensor` is not owned. +macro_rules! impl_dltensor_from_ndarray { + ($type:ty, $typecode:expr) => { + impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor { + fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self { + DLTensor { + data: arr.as_mut_ptr() as *mut c_void, + ctx: DLContext { + device_type: DLDeviceType_kDLCPU, + device_id: 0, + }, + ndim: arr.ndim() as c_int, + dtype: DLDataType { + code: $typecode as u8, + bits: 8 * mem::size_of::<$type>() as u8, + lanes: 1, + }, + shape: arr.shape().as_ptr() as *const i64 as *mut i64, + strides: arr.strides().as_ptr() as *const isize as *mut i64, + byte_offset: 0, + } + } + } + }; +} + +impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat); +impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat); +impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt); +impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt); +impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt); +impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt); diff --git a/rust/common/src/c_runtime_api.rs b/rust/common/src/c_runtime_api.rs deleted file mode 100644 index 6facf9ca274f..000000000000 --- a/rust/common/src/c_runtime_api.rs +++ /dev/null @@ -1,770 +0,0 @@ -/* automatically generated by rust-bindgen for TVM revision 6292c78 */ - -pub const TVM_VERSION: &'static [u8; 8usize] = b"0.5.dev\0"; -pub const DLPACK_VERSION: u32 = 8; -pub const _STDINT_H: u32 = 1; -pub const _FEATURES_H: u32 = 1; -pub const _DEFAULT_SOURCE: u32 = 1; -pub const __USE_ISOC11: u32 = 1; -pub const __USE_ISOC99: u32 = 1; -pub const __USE_ISOC95: u32 = 1; -pub const __USE_POSIX_IMPLICITLY: u32 = 1; -pub const _POSIX_SOURCE: u32 = 1; -pub const _POSIX_C_SOURCE: u32 = 200809; -pub const __USE_POSIX: u32 = 1; -pub const __USE_POSIX2: u32 = 1; -pub const __USE_POSIX199309: u32 = 1; -pub const __USE_POSIX199506: u32 = 1; -pub const __USE_XOPEN2K: u32 = 1; -pub const __USE_XOPEN2K8: u32 = 1; -pub const _ATFILE_SOURCE: u32 = 1; -pub const __USE_MISC: u32 = 1; -pub const __USE_ATFILE: u32 = 1; -pub const __USE_FORTIFY_LEVEL: u32 = 0; -pub const _STDC_PREDEF_H: u32 = 1; -pub const __STDC_IEC_559__: u32 = 1; -pub const __STDC_IEC_559_COMPLEX__: u32 = 1; -pub const __STDC_ISO_10646__: u32 = 201505; -pub const __STDC_NO_THREADS__: u32 = 1; -pub const __GNU_LIBRARY__: u32 = 6; -pub const __GLIBC__: u32 = 2; -pub const __GLIBC_MINOR__: u32 = 23; -pub const _SYS_CDEFS_H: u32 = 1; -pub const __WORDSIZE: u32 = 64; -pub const __WORDSIZE_TIME64_COMPAT32: u32 = 1; -pub const __SYSCALL_WORDSIZE: u32 = 64; -pub const _BITS_WCHAR_H: u32 = 1; -pub const INT8_MIN: i32 = -128; -pub const INT16_MIN: i32 = -32768; -pub const INT32_MIN: i32 = -2147483648; -pub const INT8_MAX: u32 = 127; -pub const INT16_MAX: u32 = 32767; -pub const INT32_MAX: u32 = 2147483647; -pub const UINT8_MAX: u32 = 255; -pub const UINT16_MAX: u32 = 65535; -pub const UINT32_MAX: u32 = 4294967295; -pub const INT_LEAST8_MIN: i32 = -128; -pub const INT_LEAST16_MIN: i32 = -32768; -pub const INT_LEAST32_MIN: i32 = -2147483648; -pub const INT_LEAST8_MAX: u32 = 127; -pub const INT_LEAST16_MAX: u32 = 32767; -pub const INT_LEAST32_MAX: u32 = 2147483647; -pub const UINT_LEAST8_MAX: u32 = 255; -pub const UINT_LEAST16_MAX: u32 = 65535; -pub const UINT_LEAST32_MAX: u32 = 4294967295; -pub const INT_FAST8_MIN: i32 = -128; -pub const INT_FAST16_MIN: i64 = -9223372036854775808; -pub const INT_FAST32_MIN: i64 = -9223372036854775808; -pub const INT_FAST8_MAX: u32 = 127; -pub const INT_FAST16_MAX: u64 = 9223372036854775807; -pub const INT_FAST32_MAX: u64 = 9223372036854775807; -pub const UINT_FAST8_MAX: u32 = 255; -pub const UINT_FAST16_MAX: i32 = -1; -pub const UINT_FAST32_MAX: i32 = -1; -pub const INTPTR_MIN: i64 = -9223372036854775808; -pub const INTPTR_MAX: u64 = 9223372036854775807; -pub const UINTPTR_MAX: i32 = -1; -pub const PTRDIFF_MIN: i64 = -9223372036854775808; -pub const PTRDIFF_MAX: u64 = 9223372036854775807; -pub const SIG_ATOMIC_MIN: i32 = -2147483648; -pub const SIG_ATOMIC_MAX: u32 = 2147483647; -pub const SIZE_MAX: i32 = -1; -pub const WINT_MIN: u32 = 0; -pub const WINT_MAX: u32 = 4294967295; -pub type int_least8_t = ::std::os::raw::c_schar; -pub type int_least16_t = ::std::os::raw::c_short; -pub type int_least32_t = ::std::os::raw::c_int; -pub type int_least64_t = ::std::os::raw::c_long; -pub type uint_least8_t = ::std::os::raw::c_uchar; -pub type uint_least16_t = ::std::os::raw::c_ushort; -pub type uint_least32_t = ::std::os::raw::c_uint; -pub type uint_least64_t = ::std::os::raw::c_ulong; -pub type int_fast8_t = ::std::os::raw::c_schar; -pub type int_fast16_t = ::std::os::raw::c_long; -pub type int_fast32_t = ::std::os::raw::c_long; -pub type int_fast64_t = ::std::os::raw::c_long; -pub type uint_fast8_t = ::std::os::raw::c_uchar; -pub type uint_fast16_t = ::std::os::raw::c_ulong; -pub type uint_fast32_t = ::std::os::raw::c_ulong; -pub type uint_fast64_t = ::std::os::raw::c_ulong; -pub type intmax_t = ::std::os::raw::c_long; -pub type uintmax_t = ::std::os::raw::c_ulong; -pub type wchar_t = ::std::os::raw::c_int; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct max_align_t { - pub __clang_max_align_nonce1: ::std::os::raw::c_longlong, - pub __bindgen_padding_0: u64, - pub __clang_max_align_nonce2: f64, -} -pub const DLDeviceType_kDLCPU: DLDeviceType = 1; -pub const DLDeviceType_kDLGPU: DLDeviceType = 2; -pub const DLDeviceType_kDLCPUPinned: DLDeviceType = 3; -pub const DLDeviceType_kDLOpenCL: DLDeviceType = 4; -pub const DLDeviceType_kDLMetal: DLDeviceType = 8; -pub const DLDeviceType_kDLVPI: DLDeviceType = 9; -pub const DLDeviceType_kDLROCM: DLDeviceType = 10; -/// \brief The device type in DLContext. -pub type DLDeviceType = u32; -/// \brief A Device context for Tensor and operator. -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct DLContext { - /// \brief The device type used in the device. - pub device_type: DLDeviceType, - /// \brief The device index - pub device_id: ::std::os::raw::c_int, -} -pub const DLDataTypeCode_kDLInt: DLDataTypeCode = 0; -pub const DLDataTypeCode_kDLUInt: DLDataTypeCode = 1; -pub const DLDataTypeCode_kDLFloat: DLDataTypeCode = 2; -/// \brief The type code options DLDataType. -pub type DLDataTypeCode = u32; -/// \brief The data type the tensor can hold. -/// -/// Examples -/// - float: type_code = 2, bits = 32, lanes=1 -/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 -/// - int8: type_code = 0, bits = 8, lanes=1 -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct DLDataType { - /// \brief Type code of base types. - /// We keep it uint8_t instead of DLDataTypeCode for minimal memory - /// footprint, but the value should be one of DLDataTypeCode enum values. - /// - pub code: u8, - /// \brief Number of bits, common choices are 8, 16, 32. - pub bits: u8, - /// \brief Number of lanes in the type, used for vector types. - pub lanes: u16, -} -/// \brief Plain C Tensor object, does not manage memory. -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct DLTensor { - /// \brief The opaque data pointer points to the allocated data. - /// This will be CUDA device pointer or cl_mem handle in OpenCL. - /// This pointer is always aligns to 256 bytes as in CUDA. - pub data: *mut ::std::os::raw::c_void, - /// \brief The device context of the tensor - pub ctx: DLContext, - /// \brief Number of dimensions - pub ndim: ::std::os::raw::c_int, - /// \brief The data type of the pointer - pub dtype: DLDataType, - /// \brief The shape of the tensor - pub shape: *mut i64, - /// \brief strides of the tensor, - /// can be NULL, indicating tensor is compact. - pub strides: *mut i64, - /// \brief The offset in bytes to the beginning pointer to data - pub byte_offset: u64, -} -/// \brief C Tensor object, manage memory of DLTensor. This data structure is -/// intended to faciliate the borrowing of DLTensor by another framework. It is -/// not meant to transfer the tensor. When the borrowing framework doesn't need -/// the tensor, it should call the deleter to notify the host that the resource -/// is no longer needed. -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct DLManagedTensor { - /// \brief DLTensor which is being memory managed - pub dl_tensor: DLTensor, - /// \brief the context of the original host framework of DLManagedTensor in - /// which DLManagedTensor is used in the framework. It can also be NULL. - pub manager_ctx: *mut ::std::os::raw::c_void, - /// \brief Destructor signature void (*)(void*) - this should be called - /// to destruct manager_ctx which holds the DLManagedTensor. It can be NULL - /// if there is no way for the caller to provide a reasonable destructor. - pub deleter: ::std::option::Option, -} -/// \brief type of array index. -pub type tvm_index_t = i64; -pub const TVMDeviceExtType_kDLAOCL: TVMDeviceExtType = 5; -pub const TVMDeviceExtType_kDLSDAccel: TVMDeviceExtType = 6; -pub const TVMDeviceExtType_kDLVulkan: TVMDeviceExtType = 7; -pub const TVMDeviceExtType_kOpenGL: TVMDeviceExtType = 11; -pub const TVMDeviceExtType_kExtDev: TVMDeviceExtType = 12; -/// \brief Extension device types in TVM -pub type TVMDeviceExtType = u32; -pub const TVMTypeCode_kHandle: TVMTypeCode = 3; -pub const TVMTypeCode_kNull: TVMTypeCode = 4; -pub const TVMTypeCode_kTVMType: TVMTypeCode = 5; -pub const TVMTypeCode_kTVMContext: TVMTypeCode = 6; -pub const TVMTypeCode_kArrayHandle: TVMTypeCode = 7; -pub const TVMTypeCode_kNodeHandle: TVMTypeCode = 8; -pub const TVMTypeCode_kModuleHandle: TVMTypeCode = 9; -pub const TVMTypeCode_kFuncHandle: TVMTypeCode = 10; -pub const TVMTypeCode_kStr: TVMTypeCode = 11; -pub const TVMTypeCode_kBytes: TVMTypeCode = 12; -pub const TVMTypeCode_kNDArrayContainer: TVMTypeCode = 13; -pub const TVMTypeCode_kExtBegin: TVMTypeCode = 15; -pub const TVMTypeCode_kNNVMFirst: TVMTypeCode = 16; -pub const TVMTypeCode_kNNVMLast: TVMTypeCode = 20; -pub const TVMTypeCode_kExtReserveEnd: TVMTypeCode = 64; -pub const TVMTypeCode_kExtEnd: TVMTypeCode = 128; -/// \brief The type code in TVMType -/// \note TVMType is used in two places. -pub type TVMTypeCode = u32; -/// \brief The data type used in TVM Runtime. -/// -/// Examples -/// - float: type_code = 2, bits = 32, lanes=1 -/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 -/// - int8: type_code = 0, bits = 8, lanes=1 -/// -/// \note Arguments TVM API function always takes bits=64 and lanes=1 -pub type TVMType = DLDataType; -/// \brief The Device information, abstract away common device types. -pub type TVMContext = DLContext; -/// \brief The tensor array stucture to TVM API. -pub type TVMArray = DLTensor; -/// \brief the array handle -pub type TVMArrayHandle = *mut TVMArray; -/// \brief Union type of values -/// being passed through API and function calls. -#[repr(C)] -#[derive(Copy, Clone)] -pub union TVMValue { - pub v_int64: i64, - pub v_float64: f64, - pub v_handle: *mut ::std::os::raw::c_void, - pub v_str: *const ::std::os::raw::c_char, - pub v_type: TVMType, - pub v_ctx: TVMContext, - _bindgen_union_align: u64, -} -/// \brief Byte array type used to pass in byte array -/// When kBytes is used as data type. -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct TVMByteArray { - pub data: *const ::std::os::raw::c_char, - pub size: usize, -} -/// \brief Handle to TVM runtime modules. -pub type TVMModuleHandle = *mut ::std::os::raw::c_void; -/// \brief Handle to packed function handle. -pub type TVMFunctionHandle = *mut ::std::os::raw::c_void; -/// \brief Handle to hold return value. -pub type TVMRetValueHandle = *mut ::std::os::raw::c_void; -/// \brief The stream that is specific to device -/// can be NULL, which indicates the default one. -pub type TVMStreamHandle = *mut ::std::os::raw::c_void; -extern "C" { - /// \brief Used for implementing C API function. - /// Set last error message before return. - /// \param msg The error message to be set. - pub fn TVMAPISetLastError(msg: *const ::std::os::raw::c_char); -} -extern "C" { - /// \brief return str message of the last error - /// all function in this file will return 0 when success - /// and -1 when an error occured, - /// TVMGetLastError can be called to retrieve the error - /// - /// this function is threadsafe and can be called by different thread - /// \return error info - pub fn TVMGetLastError() -> *const ::std::os::raw::c_char; -} -extern "C" { - /// \brief Load module from file. - /// \param file_name The file name to load the module from. - /// \param format The format of the module. - /// \param out The result module - /// - /// \return 0 when success, -1 when failure happens - /// \note The resulting module do not contain import relation. - /// It can be reconstructed by TVMModImport. - pub fn TVMModLoadFromFile( - file_name: *const ::std::os::raw::c_char, - format: *const ::std::os::raw::c_char, - out: *mut TVMModuleHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Add dep to mod's dependency. - /// This allows functions in this module to use modules. - /// - /// \param mod The module handle. - /// \param dep The dependent module to be imported. - /// \return 0 when success, -1 when failure happens - pub fn TVMModImport(mod_: TVMModuleHandle, dep: TVMModuleHandle) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Get function from the module. - /// \param mod The module handle. - /// \param func_name The name of the function. - /// \param query_imports Whether to query imported modules - /// \param out The result function, can be NULL if it is not available. - /// \return 0 when no error is thrown, -1 when failure happens - pub fn TVMModGetFunction( - mod_: TVMModuleHandle, - func_name: *const ::std::os::raw::c_char, - query_imports: ::std::os::raw::c_int, - out: *mut TVMFunctionHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Free front-end extension type resource. - /// \param handle The extension handle. - /// \param type_code The type of of the extension type. - /// \return 0 when success, -1 when failure happens - pub fn TVMExtTypeFree( - handle: *mut ::std::os::raw::c_void, - type_code: ::std::os::raw::c_int, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Free the Module - /// \param mod The module to be freed. - /// - /// \note This may not free up the module's resources. - /// If there is active TVMFunctionHandle uses the module - /// Or if this module is imported by another active module. - /// - /// The all functions remains valid until TVMFuncFree is called. - /// \return 0 when success, -1 when failure happens - pub fn TVMModFree(mod_: TVMModuleHandle) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Free the function when it is no longer needed. - /// \param func The function handle - /// \return 0 when success, -1 when failure happens - pub fn TVMFuncFree(func: TVMFunctionHandle) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Call a Packed TVM Function. - /// - /// \param func node handle of the function. - /// \param arg_values The arguments - /// \param type_codes The type codes of the arguments - /// \param num_args Number of arguments. - /// - /// \param ret_val The return value. - /// \param ret_type_code the type code of return value. - /// - /// \return 0 when success, -1 when failure happens - /// \note TVM calls always exchanges with type bits=64, lanes=1 - /// - /// \note API calls always exchanges with type bits=64, lanes=1 - /// If API call returns container handles (e.g. FunctionHandle) - /// these handles should be managed by the front-end. - /// The front-end need to call free function (e.g. TVMFuncFree) - /// to free these handles. - pub fn TVMFuncCall( - func: TVMFunctionHandle, - arg_values: *mut TVMValue, - type_codes: *mut ::std::os::raw::c_int, - num_args: ::std::os::raw::c_int, - ret_val: *mut TVMValue, - ret_type_code: *mut ::std::os::raw::c_int, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Set the return value of TVMPackedCFunc. - /// - /// This function is called by TVMPackedCFunc to set the return value. - /// When this function is not called, the function returns null by default. - /// - /// \param ret The return value handle, pass by ret in TVMPackedCFunc - /// \param value The value to be returned. - /// \param type_code The type of the value to be returned. - /// \param num_ret Number of return values, for now only 1 is supported. - pub fn TVMCFuncSetReturn( - ret: TVMRetValueHandle, - value: *mut TVMValue, - type_code: *mut ::std::os::raw::c_int, - num_ret: ::std::os::raw::c_int, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Inplace translate callback argument value to return value. - /// This is only needed for non-POD arguments. - /// - /// \param value The value to be translated. - /// \param code The type code to be translated. - /// \note This function will do a shallow copy when necessary. - /// - /// \return 0 when success, -1 when failure happens. - pub fn TVMCbArgToReturn( - value: *mut TVMValue, - code: ::std::os::raw::c_int, - ) -> ::std::os::raw::c_int; -} -/// \brief C type of packed function. -/// -/// \param args The arguments -/// \param type_codes The type codes of the arguments -/// \param num_args Number of arguments. -/// \param ret The return value handle. -/// \param resource_handle The handle additional resouce handle from fron-end. -/// \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. -/// \sa TVMCFuncSetReturn -pub type TVMPackedCFunc = ::std::option::Option< - unsafe extern "C" fn( - args: *mut TVMValue, - type_codes: *mut ::std::os::raw::c_int, - num_args: ::std::os::raw::c_int, - ret: TVMRetValueHandle, - resource_handle: *mut ::std::os::raw::c_void, - ) -> ::std::os::raw::c_int, ->; -/// \brief C callback to free the resource handle in C packed function. -/// \param resource_handle The handle additional resouce handle from fron-end. -pub type TVMPackedCFuncFinalizer = - ::std::option::Option; -/// \brief Signature for extension function declarer. -/// -/// TVM call this function to get the extension functions -/// The declarer will call register_func to register function and their name. -/// -/// \param register_func_handle The register function -/// \return 0 if success, -1 if failure happens -pub type TVMExtensionFuncDeclarer = ::std::option::Option< - unsafe extern "C" fn(register_func_handle: TVMFunctionHandle) -> ::std::os::raw::c_int, ->; -extern "C" { - /// \brief Wrap a TVMPackedCFunc to become a FunctionHandle. - /// - /// The resource_handle will be managed by TVM API, until the function is no longer used. - /// - /// \param func The packed C function. - /// \param resource_handle The resource handle from front-end, can be NULL. - /// \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL - /// \param out the result function handle. - /// \return 0 when success, -1 when failure happens - pub fn TVMFuncCreateFromCFunc( - func: TVMPackedCFunc, - resource_handle: *mut ::std::os::raw::c_void, - fin: TVMPackedCFuncFinalizer, - out: *mut TVMFunctionHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Register the function to runtime's global table. - /// - /// The registered function then can be pulled by the backend by the name. - /// - /// \param name The name of the function. - /// \param f The function to be registered. - /// \param override Whether allow override already registered function. - pub fn TVMFuncRegisterGlobal( - name: *const ::std::os::raw::c_char, - f: TVMFunctionHandle, - override_: ::std::os::raw::c_int, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Get a global function. - /// - /// \param name The name of the function. - /// \param out the result function pointer, NULL if it does not exist. - /// - /// \note The function handle of global function is managed by TVM runtime, - /// So TVMFuncFree is should not be called when it get deleted. - pub fn TVMFuncGetGlobal( - name: *const ::std::os::raw::c_char, - out: *mut TVMFunctionHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief List all the globally registered function name - /// \param out_size The number of functions - /// \param out_array The array of function names. - /// \return 0 when success, -1 when failure happens - pub fn TVMFuncListGlobalNames( - out_size: *mut ::std::os::raw::c_int, - out_array: *mut *mut *const ::std::os::raw::c_char, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Allocate a nd-array's memory, - /// including space of shape, of given spec. - /// - /// \param shape The shape of the array, the data content will be copied to out - /// \param ndim The number of dimension of the array. - /// \param dtype_code The type code of the dtype - /// \param dtype_bits The number of bits of dtype - /// \param dtype_lanes The number of lanes in the dtype. - /// \param device_type The device type of context - /// \param device_id The device id of context. - /// \param out The output handle. - /// \return 0 when success, -1 when failure happens - pub fn TVMArrayAlloc( - shape: *const tvm_index_t, - ndim: ::std::os::raw::c_int, - dtype_code: ::std::os::raw::c_int, - dtype_bits: ::std::os::raw::c_int, - dtype_lanes: ::std::os::raw::c_int, - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - out: *mut TVMArrayHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Free the TVM Array. - /// \param handle The array handle to be freed. - /// \return 0 when success, -1 when failure happens - pub fn TVMArrayFree(handle: TVMArrayHandle) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Copy array data from CPU byte array. - /// \param handle The array handle. - /// \param data the data pointer - /// \param nbytes The number of bytes to copy. - /// \return 0 when success, -1 when failure happens - pub fn TVMArrayCopyFromBytes( - handle: TVMArrayHandle, - data: *mut ::std::os::raw::c_void, - nbytes: usize, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Copy array data to CPU byte array. - /// \param handle The array handle. - /// \param data the data pointer - /// \param nbytes The number of bytes to copy. - /// \return 0 when success, -1 when failure happens - pub fn TVMArrayCopyToBytes( - handle: TVMArrayHandle, - data: *mut ::std::os::raw::c_void, - nbytes: usize, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Copy the array, both from and to must be valid during the copy. - /// \param from The array to be copied from. - /// \param to The target space. - /// \param stream The stream where the copy happens, can be NULL. - /// \return 0 when success, -1 when failure happens - pub fn TVMArrayCopyFromTo( - from: TVMArrayHandle, - to: TVMArrayHandle, - stream: TVMStreamHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Produce an array from the DLManagedTensor that shares data memory - /// with the DLManagedTensor. - /// \param from The source DLManagedTensor. - /// \param out The output array handle. - /// \return 0 when success, -1 when failure happens - pub fn TVMArrayFromDLPack( - from: *mut DLManagedTensor, - out: *mut TVMArrayHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Produce a DLMangedTensor from the array that shares data memory with - /// the array. - /// \param from The source array. - /// \param out The DLManagedTensor handle. - /// \return 0 when success, -1 when failure happens - pub fn TVMArrayToDLPack( - from: TVMArrayHandle, - out: *mut *mut DLManagedTensor, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Delete (free) a DLManagedTensor's data. - /// \param dltensor Pointer to the DLManagedTensor. - pub fn TVMDLManagedTensorCallDeleter(dltensor: *mut DLManagedTensor); -} -extern "C" { - /// \brief Create a new runtime stream. - /// - /// \param device_type The device type of context - /// \param device_id The device id of context - /// \param out The new stream handle - /// \return 0 when success, -1 when failure happens - pub fn TVMStreamCreate( - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - out: *mut TVMStreamHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Free a created stream handle. - /// - /// \param device_type The device type of context - /// \param device_id The device id of context - /// \param stream The stream to be freed - /// \return 0 when success, -1 when failure happens - pub fn TVMStreamFree( - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - stream: TVMStreamHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Set the runtime stream of current thread to be stream. - /// The subsequent calls to the same device_type - /// will use the setted stream handle. - /// The specific type of stream is runtime device dependent. - /// - /// \param device_type The device type of context - /// \param device_id The device id of context. - /// \param handle The stream handle. - /// \return 0 when success, -1 when failure happens - pub fn TVMSetStream( - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - handle: TVMStreamHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Wait until all computations on stream completes. - /// - /// \param device_type The device type of context - /// \param device_id The device id of context. - /// \param stream The stream to be synchronized. - /// \return 0 when success, -1 when failure happens - pub fn TVMSynchronize( - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - stream: TVMStreamHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Synchronize two streams of execution. - /// - /// \param device_type The device type of context - /// \param device_id The device id of context - /// \param src The source stream to synchronize. - /// \param dst The destination stream to synchronize. - /// \return 0 when success, -1 when failure happens - pub fn TVMStreamStreamSynchronize( - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - src: TVMStreamHandle, - dst: TVMStreamHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Backend function for modules to get function - /// from its environment mod_node (its imports and global function). - /// The user do should not call TVMFuncFree on func. - /// - /// \param mod_node The module handle. - /// \param func_name The name of the function. - /// \param out The result function. - /// \return 0 when no error is thrown, -1 when failure happens - pub fn TVMBackendGetFuncFromEnv( - mod_node: *mut ::std::os::raw::c_void, - func_name: *const ::std::os::raw::c_char, - out: *mut TVMFunctionHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Backend function to register system-wide library symbol. - /// - /// \param name The name of the symbol - /// \param ptr The symbol address. - /// \return 0 when no error is thrown, -1 when failure happens - pub fn TVMBackendRegisterSystemLibSymbol( - name: *const ::std::os::raw::c_char, - ptr: *mut ::std::os::raw::c_void, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Backend function to allocate temporal workspace. - /// - /// \note The result allocate spaced is ensured to be aligned to kTempAllocaAlignment. - /// - /// \param nbytes The size of the space requested. - /// \param device_type The device type which the space will be allocated. - /// \param device_id The device id which the space will be allocated. - /// \param dtype_code_hint The type code of the array elements. Only used in - /// certain backends such as OpenGL. - /// \param dtype_bits_hint The type bits of the array elements. Only used in - /// certain backends such as OpenGL. - /// \return nullptr when error is thrown, a valid ptr if success - pub fn TVMBackendAllocWorkspace( - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - nbytes: u64, - dtype_code_hint: ::std::os::raw::c_int, - dtype_bits_hint: ::std::os::raw::c_int, - ) -> *mut ::std::os::raw::c_void; -} -extern "C" { - /// \brief Backend function to free temporal workspace. - /// - /// \param ptr The result allocated space pointer. - /// \param device_type The device type which the space will be allocated. - /// \param device_id The device id which the space will be allocated. - /// \return 0 when no error is thrown, -1 when failure happens - /// - /// \sa TVMBackendAllocWorkspace - pub fn TVMBackendFreeWorkspace( - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - ptr: *mut ::std::os::raw::c_void, - ) -> ::std::os::raw::c_int; -} -/// \brief Environment for TVM parallel task. -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct TVMParallelGroupEnv { - /// \brief Auxiliary used for synchronization - pub sync_handle: *mut ::std::os::raw::c_void, - /// \brief total amount of task - pub num_task: i32, -} -/// \brief The callback function to execute a parallel lambda -/// \param task_id the task id of the function. -/// \param penv The parallel environment backs the execution. -/// \param cdata The supporting closure data. -pub type FTVMParallelLambda = ::std::option::Option< - unsafe extern "C" fn( - task_id: ::std::os::raw::c_int, - penv: *mut TVMParallelGroupEnv, - cdata: *mut ::std::os::raw::c_void, - ) -> ::std::os::raw::c_int, ->; -extern "C" { - /// \brief Backend function for running parallel jobs. - /// - /// \param flambda The parallel function to be launched. - /// \param cdata The closure data. - /// \param num_task Number of tasks to launch, can be 0, means launch - /// with all available threads. - /// - /// \return 0 when no error is thrown, -1 when failure happens - pub fn TVMBackendParallelLaunch( - flambda: FTVMParallelLambda, - cdata: *mut ::std::os::raw::c_void, - num_task: ::std::os::raw::c_int, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief BSP barrrier between parallel threads - /// \param task_id the task id of the function. - /// \param penv The parallel environment backs the execution. - /// \return 0 when no error is thrown, -1 when failure happens - pub fn TVMBackendParallelBarrier( - task_id: ::std::os::raw::c_int, - penv: *mut TVMParallelGroupEnv, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Simple static initialization function. - /// Run f once and set handle to be not null. - /// This function is mainly used for test purpose. - /// - /// \param handle An global address to indicate f - /// \param f The function to be ran - /// \param cdata The closure data to pass to the function. - /// \param nbytes Number of bytes in the closure data. - /// \return 0 when no error is thrown, -1 when failure happens - pub fn TVMBackendRunOnce( - handle: *mut *mut ::std::os::raw::c_void, - f: ::std::option::Option< - unsafe extern "C" fn(arg1: *mut ::std::os::raw::c_void) -> ::std::os::raw::c_int, - >, - cdata: *mut ::std::os::raw::c_void, - nbytes: ::std::os::raw::c_int, - ) -> ::std::os::raw::c_int; -} diff --git a/rust/common/src/errors.rs b/rust/common/src/errors.rs index a81fab9f8c8f..ad72f36433c0 100644 --- a/rust/common/src/errors.rs +++ b/rust/common/src/errors.rs @@ -1,15 +1,79 @@ -//! Error types for `TVMArgValue` and `TVMRetValue` conversions. +use std::fmt; -error_chain! { - errors { - TryFromTVMArgValueError(expected: String, actual: String) { - description("mismatched types while converting from TVMArgValue") - display("expected `{}` but given `{}`", expected, actual) +static TYPE_CODE_STRS: [&str; 15] = [ + "int", + "uint", + "float", + "handle", + "null", + "TVMType", + "TVMContext", + "ArrayHandle", + "NodeHandle", + "ModuleHandle", + "FuncHandle", + "str", + "bytes", + "NDArrayContainer", + "ExtBegin", +]; + +#[derive(Debug, Fail)] +pub struct ValueDowncastError { + actual_type_code: i64, + expected_type_code: i64, +} + +impl ValueDowncastError { + pub fn new(actual_type_code: i64, expected_type_code: i64) -> Self { + Self { + actual_type_code, + expected_type_code, } + } +} - TryFromTVMRetValueError(expected: String, actual: String) { - description("mismatched types while downcasting TVMRetValue") - display("invalid downcast: expected `{}` but given `{}`", expected, actual) +impl fmt::Display for ValueDowncastError { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!( + formatter, + "Could not downcast TVMValue: expected `{}` but was {}", + TYPE_CODE_STRS[self.actual_type_code as usize], + TYPE_CODE_STRS[self.expected_type_code as usize] + ) + } +} + +#[derive(Debug, Fail)] +#[fail(display = "Function call `{}` returned error: {}", context, message)] +pub struct FuncCallError { + context: String, + message: String, +} + +impl FuncCallError { + pub fn get_with_context(context: String) -> Self { + Self { + context, + message: unsafe { std::ffi::CStr::from_ptr(crate::ffi::TVMGetLastError()) } + .to_str() + .expect("double fault") + .to_owned(), } } } + +// error_chain! { +// errors { +// TryFromTVMRetValueError(expected_type: String, actual_type_code: i64) { +// description("mismatched types while downcasting TVMRetValue") +// display("invalid downcast: expected `{}` but was `{}`", +// expected_type, type_code_to_string(actual_type_code)) +// } +// } +// foreign_links { +// IntoString(std::ffi::IntoStringError); +// ParseInt(std::num::ParseIntError); +// Utf8(std::str::Utf8Error); +// } +// } diff --git a/rust/common/src/lib.rs b/rust/common/src/lib.rs index ad4c4f23579e..966655e802f8 100644 --- a/rust/common/src/lib.rs +++ b/rust/common/src/lib.rs @@ -1,39 +1,29 @@ //! This crate contains the refactored basic components required //! for `runtime` and `frontend` TVM crates. -#![crate_name = "tvm_common"] -#![recursion_limit = "1024"] -#![allow(non_camel_case_types, unused_imports)] -#![feature(box_syntax, try_from)] +#![feature(box_syntax, trait_alias)] #[macro_use] -extern crate error_chain; +extern crate failure; /// Unified ffi module for both runtime and frontend crates. pub mod ffi { #![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)] - #[cfg(feature = "frontend")] - pub extern crate tvm_sys as ts; + use std::os::raw::{c_char, c_int, c_void}; - #[cfg(feature = "runtime")] - pub mod runtime { - use std::os::raw::{c_char, c_int, c_void}; + include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs")); - include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs")); - - pub type BackendPackedCFunc = extern "C" fn( - args: *const TVMValue, - type_codes: *const c_int, - num_args: c_int, - ) -> c_int; - } + pub type BackendPackedCFunc = + extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int; } +pub mod array; pub mod errors; -pub mod ty; +#[macro_use] +pub mod packed_func; pub mod value; pub use errors::*; -pub use ty::TVMTypeCode; -pub use value::{TVMArgValue, TVMRetValue, TVMValue}; +pub use ffi::{TVMContext, TVMType}; +pub use packed_func::{TVMArgValue, TVMRetValue}; diff --git a/rust/common/src/packed_func.rs b/rust/common/src/packed_func.rs new file mode 100644 index 000000000000..a564fe656415 --- /dev/null +++ b/rust/common/src/packed_func.rs @@ -0,0 +1,312 @@ +use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void}; + +use failure::Error; + +pub use crate::ffi::TVMValue; +use crate::ffi::*; + +pub trait PackedFunc = + Fn(&[TVMArgValue]) -> Result + Send + Sync; + +/// Calls a packed function and returns a `TVMRetValue`. +/// +/// # Example +/// +/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)` +#[macro_export] +macro_rules! call_packed { + ($fn:expr, $($args:expr),+) => { + $fn(&[$($args.into(),)+]) + }; + ($fn:expr) => { + $fn(&Vec::new()) + }; +} + +/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way +/// to obtain a `TVMArgValue` is automatically via `call_packed!`. +#[derive(Clone, Copy)] +pub struct TVMArgValue<'a> { + pub _lifetime: PhantomData<&'a ()>, + pub value: TVMValue, + pub type_code: i64, +} + +impl<'a> TVMArgValue<'a> { + pub fn new(value: TVMValue, type_code: i64) -> Self { + TVMArgValue { + _lifetime: PhantomData, + value: value, + type_code: type_code, + } + } +} + +#[macro_export] +macro_rules! ensure_type { + ($val:ident, $expected_type_code:expr) => { + ensure!( + $val.type_code == $expected_type_code as i64, + $crate::errors::ValueDowncastError::new( + $val.type_code as i64, + $expected_type_code as i64 + ) + ); + }; +} + +/// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode. +macro_rules! impl_prim_tvm_arg { + ($type_code:ident, $field:ident, $field_type:ty, [ $( $type:ty ),+ ] ) => { + $( + impl From<$type> for TVMArgValue<'static> { + fn from(val: $type) -> Self { + TVMArgValue { + value: TVMValue { $field: val as $field_type }, + type_code: $type_code as i64, + _lifetime: PhantomData, + } + } + } + impl<'a> From<&'a $type> for TVMArgValue<'a> { + fn from(val: &'a $type) -> Self { + TVMArgValue { + value: TVMValue { + $field: val.to_owned() as $field_type, + }, + type_code: $type_code as i64, + _lifetime: PhantomData, + } + } + } + impl<'a> TryFrom> for $type { + type Error = Error; + fn try_from(val: TVMArgValue<'a>) -> Result { + ensure_type!(val, $type_code); + Ok(unsafe { val.value.$field as $type }) + } + } + + impl<'a> TryFrom<&TVMArgValue<'a>> for $type { + type Error = Error; + fn try_from(val: &TVMArgValue<'a>) -> Result { + ensure_type!(val, $type_code); + Ok(unsafe { val.value.$field as $type }) + } + } + )+ + }; +} + +impl_prim_tvm_arg!(DLDataTypeCode_kDLFloat, v_float64, f64, [f32, f64]); +impl_prim_tvm_arg!( + DLDataTypeCode_kDLInt, + v_int64, + i64, + [i8, i16, i32, i64, isize] +); +impl_prim_tvm_arg!( + DLDataTypeCode_kDLUInt, + v_int64, + i64, + [u8, u16, u32, u64, usize] +); + +#[cfg(feature = "bindings")] +// only allow this in bindings because pure-rust can't take ownership of leaked CString +impl<'a> From<&String> for TVMArgValue<'a> { + fn from(string: &String) -> Self { + TVMArgValue { + value: TVMValue { + v_str: std::ffi::CString::new(string.clone()).unwrap().into_raw(), + }, + type_code: TVMTypeCode_kStr as i64, + _lifetime: PhantomData, + } + } +} + +impl<'a> From<&std::ffi::CString> for TVMArgValue<'a> { + fn from(string: &std::ffi::CString) -> Self { + TVMArgValue { + value: TVMValue { + v_str: string.as_ptr(), + }, + type_code: TVMTypeCode_kStr as i64, + _lifetime: PhantomData, + } + } +} + +impl<'a> TryFrom> for &str { + type Error = Error; + fn try_from(arg: TVMArgValue<'a>) -> Result { + ensure_type!(arg, TVMTypeCode_kStr); + Ok(unsafe { std::ffi::CStr::from_ptr(arg.value.v_handle as *const i8) }.to_str()?) + } +} + +impl<'a> TryFrom<&TVMArgValue<'a>> for &str { + type Error = Error; + fn try_from(arg: &TVMArgValue<'a>) -> Result { + ensure_type!(arg, TVMTypeCode_kStr); + Ok(unsafe { std::ffi::CStr::from_ptr(arg.value.v_handle as *const i8) }.to_str()?) + } +} + +/// Creates a conversion to a `TVMArgValue` for an object handle. +impl<'a, T> From<*const T> for TVMArgValue<'a> { + fn from(ptr: *const T) -> Self { + TVMArgValue { + value: TVMValue { + v_handle: ptr as *mut T as *mut c_void, + }, + type_code: TVMTypeCode_kArrayHandle as i64, + _lifetime: PhantomData, + } + } +} + +/// Creates a conversion to a `TVMArgValue` for a mutable object handle. +impl<'a, T> From<*mut T> for TVMArgValue<'a> { + fn from(ptr: *mut T) -> Self { + TVMArgValue { + value: TVMValue { + v_handle: ptr as *mut c_void, + }, + type_code: TVMTypeCode_kHandle as i64, + _lifetime: PhantomData, + } + } +} + +impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> { + fn from(arr: &'a mut DLTensor) -> Self { + TVMArgValue { + value: TVMValue { + v_handle: arr as *mut _ as *mut c_void, + }, + type_code: TVMTypeCode_kArrayHandle as i64, + _lifetime: PhantomData, + } + } +} + +impl<'a> From<&'a DLTensor> for TVMArgValue<'a> { + fn from(arr: &'a DLTensor) -> Self { + TVMArgValue { + value: TVMValue { + v_handle: arr as *const _ as *mut DLTensor as *mut c_void, + }, + type_code: TVMTypeCode_kArrayHandle as i64, + _lifetime: PhantomData, + } + } +} + +impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for TVMType { + type Error = Error; + fn try_from(arg: &'a TVMArgValue<'v>) -> Result { + ensure_type!(arg, TVMTypeCode_kTVMType); + Ok(unsafe { arg.value.v_type.into() }) + } +} + +/// An owned TVMPODValue. Can be converted from a variety of primitive and object types. +/// Can be downcasted using `try_from` if it contains the desired type. +/// +/// # Example +/// +/// ``` +/// let a = 42u32; +/// let b: i64 = TVMRetValue::from(a).try_into().unwrap(); +/// +/// let s = "hello, world!"; +/// let t: TVMRetValue = s.into(); +/// assert_eq!(String::try_from(t).unwrap(), s); +/// ``` +pub struct TVMRetValue { + pub value: TVMValue, + pub box_value: Box, + pub type_code: i64, +} + +impl TVMRetValue { + pub fn from_tvm_value(value: TVMValue, type_code: i64) -> Self { + Self { + value, + type_code, + box_value: box (), + } + } + + pub fn into_tvm_value(self) -> (TVMValue, TVMTypeCode) { + (self.value, self.type_code as TVMTypeCode) + } +} + +impl Default for TVMRetValue { + fn default() -> Self { + TVMRetValue { + value: TVMValue { v_int64: 0 as i64 }, + type_code: 0, + box_value: box (), + } + } +} + +macro_rules! impl_pod_ret_value { + ($code:expr, [ $( $ty:ty ),+ ] ) => { + $( + impl From<$ty> for TVMRetValue { + fn from(val: $ty) -> Self { + Self { + value: val.into(), + type_code: $code as i64, + box_value: box (), + } + } + } + + impl TryFrom for $ty { + type Error = Error; + fn try_from(ret: TVMRetValue) -> Result<$ty, Self::Error> { + ensure_type!(ret, $code); + Ok(ret.value.into()) + } + } + )+ + }; +} + +impl_pod_ret_value!(DLDataTypeCode_kDLInt, [i8, i16, i32, i64, isize]); +impl_pod_ret_value!(DLDataTypeCode_kDLUInt, [u8, u16, u32, u64, usize]); +impl_pod_ret_value!(DLDataTypeCode_kDLFloat, [f32, f64]); +impl_pod_ret_value!(TVMTypeCode_kTVMType, [TVMType]); +impl_pod_ret_value!(TVMTypeCode_kTVMContext, [TVMContext]); + +impl TryFrom for String { + type Error = Error; + fn try_from(ret: TVMRetValue) -> Result { + ensure_type!(ret, TVMTypeCode_kStr); + let cs = unsafe { std::ffi::CString::from_raw(ret.value.v_handle as *mut i8) }; + let ret_str = cs.clone().into_string(); + if cfg!(feature = "bindings") { + std::mem::forget(cs); // TVM C++ takes ownership of CString. (@see TVMFuncCall) + } + Ok(ret_str?) + } +} + +impl From for TVMRetValue { + fn from(s: String) -> Self { + let cs = std::ffi::CString::new(s).unwrap(); + Self { + value: TVMValue { + v_str: cs.into_raw() as *mut i8, + }, + box_value: box (), + type_code: TVMTypeCode_kStr as i64, + } + } +} diff --git a/rust/common/src/ty.rs b/rust/common/src/ty.rs index 126bcd44527e..e69de29bb2d1 100644 --- a/rust/common/src/ty.rs +++ b/rust/common/src/ty.rs @@ -1,144 +0,0 @@ -//! This module containes `TVMTypeCode` and `TVMType` with some conversion methods. -//! -//! # Example -//! -//! ``` -//! let dtype = TVMType::from("float"); -//! println!("dtype is: {}", dtype); -//! ``` - -use std::{ - ffi::{CStr, CString}, - fmt::{self, Display, Formatter}, -}; - -/// TVM type codes. -#[repr(u32)] -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -pub enum TVMTypeCode { - kDLInt = 0, - kDLUInt = 1, - kDLFloat = 2, - kHandle = 3, - kNull = 4, - kTVMType = 5, - kTVMContext = 6, - kArrayHandle = 7, - kNodeHandle = 8, - kModuleHandle = 9, - kFuncHandle = 10, - kStr = 11, - kBytes = 12, - kNDArrayContainer = 13, -} - -impl Default for TVMTypeCode { - fn default() -> Self { - TVMTypeCode::kDLInt - } -} - -impl From for i64 { - fn from(arg: TVMTypeCode) -> i64 { - match arg { - TVMTypeCode::kDLInt => 0, - TVMTypeCode::kDLUInt => 1, - TVMTypeCode::kDLFloat => 2, - TVMTypeCode::kHandle => 3, - TVMTypeCode::kNull => 4, - TVMTypeCode::kTVMType => 5, - TVMTypeCode::kTVMContext => 6, - TVMTypeCode::kArrayHandle => 7, - TVMTypeCode::kNodeHandle => 8, - TVMTypeCode::kModuleHandle => 9, - TVMTypeCode::kFuncHandle => 10, - TVMTypeCode::kStr => 11, - TVMTypeCode::kBytes => 12, - TVMTypeCode::kNDArrayContainer => 13, - } - } -} - -impl Into for i64 { - fn into(self) -> TVMTypeCode { - match self { - 0 => TVMTypeCode::kDLInt, - 1 => TVMTypeCode::kDLUInt, - 2 => TVMTypeCode::kDLFloat, - 3 => TVMTypeCode::kHandle, - 4 => TVMTypeCode::kNull, - 5 => TVMTypeCode::kTVMType, - 6 => TVMTypeCode::kTVMContext, - 7 => TVMTypeCode::kArrayHandle, - 8 => TVMTypeCode::kNodeHandle, - 9 => TVMTypeCode::kModuleHandle, - 10 => TVMTypeCode::kFuncHandle, - 11 => TVMTypeCode::kStr, - 12 => TVMTypeCode::kBytes, - 13 => TVMTypeCode::kNDArrayContainer, - _ => unreachable!(), - } - } -} - -impl Display for TVMTypeCode { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!( - f, - "{}", - match self { - TVMTypeCode::kDLInt => "int", - TVMTypeCode::kDLUInt => "uint", - TVMTypeCode::kDLFloat => "float", - TVMTypeCode::kHandle => "handle", - TVMTypeCode::kNull => "null", - TVMTypeCode::kTVMType => "TVM type", - TVMTypeCode::kTVMContext => "TVM context", - TVMTypeCode::kArrayHandle => "Array handle", - TVMTypeCode::kNodeHandle => "Node handle", - TVMTypeCode::kModuleHandle => "Module handle", - TVMTypeCode::kFuncHandle => "Function handle", - TVMTypeCode::kStr => "string", - TVMTypeCode::kBytes => "bytes", - TVMTypeCode::kNDArrayContainer => "ndarray container", - } - ) - } -} - -macro_rules! impl_prim_type { - ($type:ty, $variant:ident) => { - impl<'a> From<&'a $type> for TVMTypeCode { - fn from(_arg: &$type) -> Self { - TVMTypeCode::$variant - } - } - - impl<'a> From<&'a mut $type> for TVMTypeCode { - fn from(_arg: &mut $type) -> Self { - TVMTypeCode::$variant - } - } - }; -} - -impl_prim_type!(usize, kDLInt); -impl_prim_type!(i64, kDLInt); -impl_prim_type!(i32, kDLInt); -impl_prim_type!(i16, kDLInt); -impl_prim_type!(i8, kDLInt); - -impl_prim_type!(u64, kDLUInt); -impl_prim_type!(u32, kDLUInt); -impl_prim_type!(u16, kDLUInt); -impl_prim_type!(u8, kDLUInt); - -impl_prim_type!(f64, kDLFloat); -impl_prim_type!(f32, kDLFloat); - -impl_prim_type!(str, kStr); -impl_prim_type!(CStr, kStr); -impl_prim_type!(String, kStr); -impl_prim_type!(CString, kStr); - -impl_prim_type!([u8], kBytes); diff --git a/rust/common/src/value.rs b/rust/common/src/value.rs index 6da8b27e8660..c7c040b0060e 100644 --- a/rust/common/src/value.rs +++ b/rust/common/src/value.rs @@ -1,559 +1,139 @@ -//! This module provides the the wrapped `TVMValue`, `TVMArgValue` and `TVMRetValue` -//! required for using TVM functions. +use std::str::FromStr; -use std::{ - any::Any, - convert::TryFrom, - ffi::{CStr, CString}, - fmt::{self, Debug, Formatter}, - marker::PhantomData, - mem, - ops::Deref, - os::raw::{c_char, c_void}, -}; +use failure::Error; -#[cfg(feature = "runtime")] -use ffi::runtime::TVMValue as _TVMValue; +use crate::ffi::*; -#[cfg(feature = "frontend")] -use ffi::ts::TVMValue as _TVMValue; - -use errors::*; - -use ty::TVMTypeCode; - -/// Wrapped TVMValue type. -#[derive(Clone, Copy)] -pub struct TVMValue { - pub inner: _TVMValue, -} - -impl TVMValue { - /// Creates TVMValue from the raw part. - pub fn new(inner: _TVMValue) -> Self { - TVMValue { inner } - } - - pub(crate) fn into_raw(self) -> _TVMValue { - self.inner - } -} - -impl Debug for TVMValue { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - unsafe { - write!( - f, - "TVMValue: [v_int64: {:?}], [v_float64: {:?}], [v_handle: {:?}],\ - [v_str: {:?}]", - self.inner.v_int64, self.inner.v_float64, self.inner.v_handle, self.inner.v_str - ) +impl TVMType { + fn new(type_code: u8, bits: u8, lanes: u16) -> Self { + Self { + code: type_code, + bits, + lanes, } } } -impl Deref for TVMValue { - type Target = _TVMValue; - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -macro_rules! impl_prim_val { - ($type:ty, $field:ident, $cast:ty) => { - impl From<$type> for TVMValue { - fn from(arg: $type) -> Self { - let inner = _TVMValue { - $field: arg as $cast, - }; - Self::new(inner) - } +/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}` +/// such as "int32", "float32" or with lane "float32x1". +impl FromStr for TVMType { + type Err = Error; + fn from_str(type_str: &str) -> Result { + if type_str == "bool" { + return Ok(TVMType::new(1, 1, 1)); } - impl<'a> From<&'a $type> for TVMValue { - fn from(arg: &$type) -> Self { - let inner = _TVMValue { - $field: *arg as $cast, - }; - Self::new(inner) + let mut type_lanes = type_str.split("x"); + let typ = type_lanes.next().expect("Missing dtype"); + let lanes = type_lanes + .next() + .map(|l| ::from_str_radix(l, 10)) + .unwrap_or(Ok(1))?; + let (type_name, bits) = match typ.find(char::is_numeric) { + Some(idx) => { + let (name, bits_str) = typ.split_at(idx); + (name, u8::from_str_radix(bits_str, 10)?) } - } - - impl<'a> From<&'a mut $type> for TVMValue { - fn from(arg: &mut $type) -> Self { - let inner = _TVMValue { - $field: *arg as $cast, - }; - Self::new(inner) - } - } - - impl TryFrom for $type { - type Error = Error; - fn try_from(val: TVMValue) -> Result { - Ok(unsafe { val.inner.$field as $type }) - } - } - - impl<'a> TryFrom<&'a TVMValue> for $type { - type Error = Error; - fn try_from(val: &TVMValue) -> Result { - Ok(unsafe { val.into_raw().$field as $type }) - } - } - - impl<'a> TryFrom<&'a mut TVMValue> for $type { - type Error = Error; - fn try_from(val: &mut TVMValue) -> Result { - Ok(unsafe { val.into_raw().$field as $type }) - } - } - }; -} - -impl_prim_val!(isize, v_int64, i64); -impl_prim_val!(i64, v_int64, i64); -impl_prim_val!(i32, v_int64, i64); -impl_prim_val!(i16, v_int64, i64); -impl_prim_val!(i8, v_int64, i64); -impl_prim_val!(usize, v_int64, i64); -impl_prim_val!(u64, v_int64, i64); -impl_prim_val!(u32, v_int64, i64); -impl_prim_val!(u16, v_int64, i64); -impl_prim_val!(u8, v_int64, i64); - -impl_prim_val!(f64, v_float64, f64); -impl_prim_val!(f32, v_float64, f64); - -impl<'a> From<&'a str> for TVMValue { - fn from(arg: &str) -> TVMValue { - let arg = CString::new(arg).unwrap(); - let inner = _TVMValue { - v_str: arg.as_ptr() as *const c_char, + None => (typ, 32), }; - mem::forget(arg); - Self::new(inner) - } -} - -impl<'a> From<&'a String> for TVMValue { - fn from(arg: &String) -> TVMValue { - let arg = CString::new(arg.as_bytes()).unwrap(); - let inner = _TVMValue { - v_str: arg.as_ptr() as *const c_char, - }; - mem::forget(arg); - Self::new(inner) - } -} - -impl<'a> From<&'a CString> for TVMValue { - fn from(arg: &CString) -> TVMValue { - let arg = arg.to_owned(); - let inner = _TVMValue { - v_str: arg.as_ptr() as *const c_char, - }; - mem::forget(arg); - Self::new(inner) - } -} -impl<'a> From<&'a [u8]> for TVMValue { - fn from(arg: &[u8]) -> TVMValue { - let arg = arg.to_owned(); - let inner = _TVMValue { - v_handle: &arg as *const _ as *mut c_void, + let type_code = match type_name { + "int" => 0, + "uint" => 1, + "float" => 2, + "handle" => 3, + _ => return Err(format_err!("Unknown type {}", type_name)), }; - mem::forget(arg); - Self::new(inner) - } -} - -/// Captures both `TVMValue` and `TVMTypeCode` needed for TVM function. -/// The preferred way to obtain a `TVMArgValue` is automatically via `call_packed!`. -/// or in the frontend crate, with `function::Builder`. Checkout the methods for conversions. -/// -/// ## Example -/// -/// ``` -/// let s = "hello".to_string(); -/// let arg = TVMArgValue::from(&s); -/// let tvm: String = arg.try_into().unwrap(); -/// assert_eq!(arg, s); -/// ``` -#[derive(Debug, Clone, Copy)] -pub struct TVMArgValue<'a> { - /// The wrapped TVMValue - pub value: TVMValue, - /// The matching type code. - pub type_code: TVMTypeCode, - /// This is only exposed to runtime and frontend crates and is not meant to be used directly. - pub lifetime: PhantomData<&'a ()>, -} - -impl<'a> TVMArgValue<'a> { - pub fn new(value: TVMValue, type_code: TVMTypeCode) -> Self { - TVMArgValue { - value: value, - type_code: type_code, - lifetime: PhantomData, - } - } -} - -impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for i64 { - type Error = Error; - fn try_from(arg: &TVMArgValue<'a>) -> Result { - if (arg.type_code == TVMTypeCode::kDLInt) - | (arg.type_code == TVMTypeCode::kDLUInt) - | (arg.type_code == TVMTypeCode::kNull) - { - Ok(unsafe { arg.value.inner.v_int64 }) - } else { - bail!(ErrorKind::TryFromTVMArgValueError( - stringify!(i64).to_string(), - arg.type_code.to_string() - )) - } - } -} - -impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for f64 { - type Error = Error; - fn try_from(arg: &TVMArgValue<'a>) -> Result { - if arg.type_code == TVMTypeCode::kDLFloat { - Ok(unsafe { arg.value.inner.v_float64 }) - } else { - bail!(ErrorKind::TryFromTVMArgValueError( - stringify!(f64).to_string(), - arg.type_code.to_string() - )) - } - } -} - -impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for String { - type Error = Error; - fn try_from(arg: &TVMArgValue<'a>) -> Result { - if arg.type_code == TVMTypeCode::kStr { - let ret_str = unsafe { - match CStr::from_ptr(arg.value.inner.v_str).to_str() { - Ok(s) => s, - Err(_) => "Invalid UTF-8 message", - } - }; - Ok(ret_str.to_string()) - } else { - bail!(ErrorKind::TryFromTVMArgValueError( - stringify!(String).to_string(), - arg.type_code.to_string() - )) - } - } -} - -/// Main way to create a TVMArgValue from suported Rust values. -impl<'b, 'a: 'b, T: 'b + ?Sized> From<&'b T> for TVMArgValue<'a> -where - TVMValue: From<&'b T>, - TVMTypeCode: From<&'b T>, -{ - fn from(arg: &'b T) -> Self { - TVMArgValue::new(TVMValue::from(arg), TVMTypeCode::from(arg)) - } -} - -/// Creates a conversion to a `TVMArgValue` for an object handle. -impl<'a, T> From<*const T> for TVMArgValue<'a> { - fn from(ptr: *const T) -> Self { - let value = TVMValue::new(_TVMValue { - v_handle: ptr as *mut T as *mut c_void, - }); - TVMArgValue::new(value, TVMTypeCode::kArrayHandle) + Ok(TVMType::new(type_code, bits, lanes)) } } -/// Creates a conversion to a `TVMArgValue` for a mutable object handle. -impl<'a, T> From<*mut T> for TVMArgValue<'a> { - fn from(ptr: *mut T) -> Self { - let value = TVMValue::new(_TVMValue { - v_handle: ptr as *mut c_void, - }); - - TVMArgValue::new(value, TVMTypeCode::kHandle) - } -} - -/// An owned version of TVMPODValue. It can be converted from varieties of -/// primitive and object types. -/// It can be downcasted using `try_from` if it contains the desired type. -/// -/// # Example -/// -/// ``` -/// let a = 42u32; -/// let b: i64 = TVMRetValue::from(a).try_into().unwrap(); -/// -/// let s = "hello, world!"; -/// let t: TVMRetValue = s.into(); -/// assert_eq!(String::try_from(t).unwrap(), s); -/// ``` -pub struct TVMRetValue { - /// A primitive return value, if any. - pub prim_value: usize, - /// An object return value, if any. - pub box_value: Box, - pub type_code: TVMTypeCode, -} - -impl TVMRetValue { - fn new(prim_value: usize, box_value: Box, type_code: TVMTypeCode) -> Self { - Self { - prim_value, - box_value, - type_code, +impl std::fmt::Display for TVMType { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + if self.bits == 1 && self.lanes == 1 { + return write!(f, "bool"); } - } - - /// unsafe function to create `TVMRetValue` from `TVMValue` and - /// its matching `TVMTypeCode`. - pub unsafe fn from_tvm_value(value: TVMValue, type_code: TVMTypeCode) -> Self { - let value = value.into_raw(); - match type_code { - TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt => { - Self::new(value.v_int64 as usize, box (), type_code) - } - TVMTypeCode::kDLFloat => Self::new(value.v_float64 as usize, box (), type_code), - TVMTypeCode::kHandle - | TVMTypeCode::kArrayHandle - | TVMTypeCode::kNodeHandle - | TVMTypeCode::kModuleHandle - | TVMTypeCode::kFuncHandle => { - Self::new(value.v_handle as usize, box value.v_handle, type_code) - } - TVMTypeCode::kStr | TVMTypeCode::kBytes => { - Self::new(value.v_str as usize, box (value.v_str), type_code) - } - _ => Self::new(0usize, box (), type_code), + let mut type_str = match self.code { + 0 => "int", + 1 => "uint", + 2 => "float", + 4 => "handle", + _ => "unknown", } - } + .to_string(); - /// Returns the underlying `TVMValue` and `TVMTypeCode`. - pub fn into_tvm_value(self) -> (TVMValue, TVMTypeCode) { - let val = match self.type_code { - TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt => TVMValue::new(_TVMValue { - v_int64: self.prim_value as i64, - }), - TVMTypeCode::kDLFloat => TVMValue::new(_TVMValue { - v_float64: self.prim_value as f64, - }), - TVMTypeCode::kHandle - | TVMTypeCode::kArrayHandle - | TVMTypeCode::kNodeHandle - | TVMTypeCode::kModuleHandle - | TVMTypeCode::kFuncHandle - | TVMTypeCode::kNDArrayContainer => TVMValue::new(_TVMValue { - v_handle: self.prim_value as *const c_void as *mut c_void, - }), - TVMTypeCode::kStr | TVMTypeCode::kBytes => TVMValue::new(_TVMValue { - v_str: self.prim_value as *const c_char, - }), - _ => unreachable!(), - }; - (val, self.type_code) - } -} - -impl Default for TVMRetValue { - fn default() -> Self { - TVMRetValue { - prim_value: 0usize, - box_value: box (), - type_code: TVMTypeCode::default(), - } - } -} - -impl Clone for TVMRetValue { - fn clone(&self) -> Self { - match self.type_code { - TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt | TVMTypeCode::kDLFloat => { - Self::new(self.prim_value.clone(), box (), self.type_code.clone()) - } - TVMTypeCode::kHandle - | TVMTypeCode::kArrayHandle - | TVMTypeCode::kNodeHandle - | TVMTypeCode::kModuleHandle - | TVMTypeCode::kFuncHandle - | TVMTypeCode::kNDArrayContainer => Self::new( - self.prim_value.clone(), - box (self.prim_value.clone() as *const c_void as *mut c_void), - self.type_code.clone(), - ), - TVMTypeCode::kStr | TVMTypeCode::kBytes => Self::new( - self.prim_value.clone(), - box (self.prim_value.clone() as *const c_char), - self.type_code.clone(), - ), - _ => unreachable!(), + type_str += &self.bits.to_string(); + if self.lanes > 1 { + type_str += &format!("x{}", self.lanes); } + f.write_str(&type_str) } } -impl Debug for TVMRetValue { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!( - f, - "prim_value: {:?}, box_value: {:?}, type_code: {:?}", - self.prim_value, self.prim_value as *const c_void as *mut c_void, self.type_code - ) - } -} - -macro_rules! impl_prim_ret_value { - ($type:ty, $code:expr) => { - impl From<$type> for TVMRetValue { - fn from(val: $type) -> Self { - TVMRetValue { - prim_value: val as usize, - box_value: box (), - type_code: $code, - } - } - } - - impl<'a> From<&'a $type> for TVMRetValue { - fn from(val: &$type) -> Self { - TVMRetValue { - prim_value: *val as usize, - box_value: box (), - type_code: $code, +macro_rules! impl_pod_tvm_value { + ($field:ident, $field_ty:ty, $( $ty:ty ),+) => { + $( + impl From<$ty> for TVMValue { + fn from(val: $ty) -> Self { + TVMValue { $field: val as $field_ty } } } - } - impl<'a> From<&'a mut $type> for TVMRetValue { - fn from(val: &mut $type) -> Self { - TVMRetValue { - prim_value: *val as usize, - box_value: box (), - type_code: $code, + impl From for $ty { + fn from(val: TVMValue) -> Self { + unsafe { val.$field as $ty } } } - } - - impl TryFrom for $type { - type Error = Error; - fn try_from(ret: TVMRetValue) -> Result<$type> { - if ret.type_code == $code { - Ok(ret.prim_value as $type) - } else { - bail!(ErrorKind::TryFromTVMRetValueError( - stringify!($type).to_string(), - ret.type_code.to_string(), - )) - } - } - } + )+ }; + ($field:ident, $ty:ty) => { + impl_pod_tvm_value!($field, $ty, $ty); + } } -impl_prim_ret_value!(i8, TVMTypeCode::kDLInt); -impl_prim_ret_value!(i16, TVMTypeCode::kDLInt); -impl_prim_ret_value!(i32, TVMTypeCode::kDLInt); -impl_prim_ret_value!(i64, TVMTypeCode::kDLInt); -impl_prim_ret_value!(isize, TVMTypeCode::kDLInt); - -impl_prim_ret_value!(u8, TVMTypeCode::kDLUInt); -impl_prim_ret_value!(u16, TVMTypeCode::kDLUInt); -impl_prim_ret_value!(u32, TVMTypeCode::kDLUInt); -impl_prim_ret_value!(u64, TVMTypeCode::kDLUInt); -impl_prim_ret_value!(usize, TVMTypeCode::kDLUInt); - -impl_prim_ret_value!(f32, TVMTypeCode::kDLFloat); -impl_prim_ret_value!(f64, TVMTypeCode::kDLFloat); +impl_pod_tvm_value!(v_int64, i64, i8, u8, i16, u16, i32, u32, i64, u64, isize, usize); +impl_pod_tvm_value!(v_float64, f64, f32, f64); +impl_pod_tvm_value!(v_type, TVMType); +impl_pod_tvm_value!(v_ctx, TVMContext); -macro_rules! impl_ptr_ret_value { - ($type:ty) => { - impl From<$type> for TVMRetValue { - fn from(ptr: $type) -> Self { - TVMRetValue { - prim_value: ptr as usize, - box_value: box (), - type_code: TVMTypeCode::kHandle, - } - } - } - - impl TryFrom for $type { - type Error = Error; - fn try_from(ret: TVMRetValue) -> Result<$type> { - if ret.type_code == TVMTypeCode::kHandle { - Ok(ret.prim_value as $type) - } else { - bail!(ErrorKind::TryFromTVMRetValueError( - stringify!($type).to_string(), - ret.type_code.to_string(), - )) - } +macro_rules! impl_tvm_context { + ( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => { + /// Creates a TVMContext from a string (e.g., "cpu", "gpu", "ext_dev") + impl FromStr for TVMContext { + type Err = Error; + fn from_str(type_str: &str) -> Result { + Ok(Self { + device_type: match type_str { + $( $( stringify!($dev_name) )|+ => $dev_type ),+, + _ => return Err(format_err!("device {} not supported", type_str).into()), + }, + device_id: 0, + }) } } - }; -} - -impl_ptr_ret_value!(*const c_void); -impl_ptr_ret_value!(*mut c_void); -impl From for TVMRetValue { - fn from(val: String) -> Self { - let pval = val.as_ptr() as *const c_char as usize; - let bval = box (val.as_ptr() as *const c_char); - mem::forget(val); - TVMRetValue::new(pval, bval, TVMTypeCode::kStr) - } -} - -impl TryFrom for String { - type Error = Error; - fn try_from(ret: TVMRetValue) -> Result { - // Note: simple downcast doesn't work for function call return values - let ret_str = unsafe { - match CStr::from_ptr(ret.prim_value as *const c_char).to_str() { - Ok(s) => s, - Err(_) => "Invalid UTF-8 message", - } - }; - - Ok(ret_str.to_string()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::convert::TryInto; - - #[test] - fn numeric() { - macro_rules! arg_ret_tests { - ($v:expr; $($ty:ty),+) => {{ + impl TVMContext { + $( $( - let v = $v as $ty; - let b = TVMRetValue::from(&v); - let b: $ty = b.try_into().unwrap(); - assert_eq!(b, v); + pub fn $dev_name(device_id: usize) -> Self { + Self { + device_type: $dev_type, + device_id: device_id as i32, + } + } )+ - }}; + )+ } - - arg_ret_tests!(42; i8, i16, i32, i64, f32, f64); - } - - #[test] - fn string() { - let s = "hello".to_string(); - let tvm_arg: String = TVMRetValue::from(s.clone()).try_into().unwrap(); - assert_eq!(tvm_arg, s); - } + }; } + +impl_tvm_context!( + DLDeviceType_kDLCPU: [cpu, llvm, stackvm], + DLDeviceType_kDLGPU: [gpu, cuda, nvptx], + DLDeviceType_kDLOpenCL: [cl], + DLDeviceType_kDLMetal: [metal], + DLDeviceType_kDLVPI: [vpi], + DLDeviceType_kDLROCM: [rocm], + DLDeviceType_kDLExtDev: [ext_dev] +); diff --git a/rust/common/tvm-sys/Cargo.toml b/rust/common/tvm-sys/Cargo.toml deleted file mode 100644 index 117d174b4cbd..000000000000 --- a/rust/common/tvm-sys/Cargo.toml +++ /dev/null @@ -1,9 +0,0 @@ -[package] -name = "tvm-sys" -version = "0.1.0" -authors = ["TVM Contributors"] -license = "Apache-2.0" -description = "Raw C API" - -[build-dependencies] -bindgen = "0.37.4" diff --git a/rust/common/tvm-sys/build.rs b/rust/common/tvm-sys/build.rs deleted file mode 100644 index f842043a1d16..000000000000 --- a/rust/common/tvm-sys/build.rs +++ /dev/null @@ -1,25 +0,0 @@ -extern crate bindgen; - -use std::path::PathBuf; - -fn main() { - println!("cargo:rerun-if-env-changed=TVM_HOME"); - println!("cargo:rustc-link-lib=dylib=tvm_runtime"); - println!("cargo:rustc-link-search={}/build", env!("TVM_HOME")); - let bindings = bindgen::Builder::default() - .header(format!( - "{}/include/tvm/runtime/c_runtime_api.h", - env!("TVM_HOME") - )) - .clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME"))) - .blacklist_type("max_align_t") // @see rust-bindgen#550 - .layout_tests(false) - .derive_partialeq(true) - .derive_eq(true) - .generate() - .expect("unable to generate bindings"); - - bindings - .write_to_file(PathBuf::from("src/bindgen.rs")) - .expect("can not write the bindings!"); -} diff --git a/rust/common/tvm-sys/src/lib.rs b/rust/common/tvm-sys/src/lib.rs deleted file mode 100644 index 15f1ea3a611c..000000000000 --- a/rust/common/tvm-sys/src/lib.rs +++ /dev/null @@ -1,9 +0,0 @@ -#![allow( - non_camel_case_types, - non_snake_case, - non_upper_case_globals, - dead_code, - improper_ctypes -)] - -include!("bindgen.rs"); diff --git a/rust/frontend/Cargo.toml b/rust/frontend/Cargo.toml index db261551e36f..eb1f5b8db021 100644 --- a/rust/frontend/Cargo.toml +++ b/rust/frontend/Cargo.toml @@ -15,11 +15,11 @@ name = "tvm_frontend" crate-type = ["dylib"] [dependencies] -error-chain = "0.12.0" +failure = "0.1.5" lazy_static = "1.1.0" ndarray = "0.12.1" num-traits = "0.2" -tvm-common = { version = "0.1.0", path = "../common/", features = ["frontend"] } +tvm-common = { version = "0.1.0", path = "../common/", features = ["bindings"] } [features] blas = ["ndarray/blas"] diff --git a/rust/frontend/examples/resnet/src/main.rs b/rust/frontend/examples/resnet/src/main.rs index 869a35b3a3a4..2ad3efa9082a 100644 --- a/rust/frontend/examples/resnet/src/main.rs +++ b/rust/frontend/examples/resnet/src/main.rs @@ -1,5 +1,3 @@ -#![feature(try_from)] - extern crate csv; extern crate image; extern crate ndarray; @@ -10,6 +8,7 @@ use std::{ convert::TryInto, fs::{self, File}, path::Path, + str::FromStr, }; use image::{FilterType, GenericImageView}; @@ -44,8 +43,12 @@ fn main() { // make arr shape as [1, 3, 224, 224] acceptable to resnet let arr = arr.insert_axis(Axis(0)); // create input tensor from rust's ndarray - let input = - NDArray::from_rust_ndarray(&arr, TVMContext::cpu(0), TVMType::from("float32")).unwrap(); + let input = NDArray::from_rust_ndarray( + &arr, + TVMContext::cpu(0), + TVMType::from_str("float32").unwrap(), + ) + .unwrap(); println!( "input size is {:?}", input.shape().expect("cannot get the input shape") @@ -59,7 +62,7 @@ fn main() { ))) .unwrap(); // get the global TVM graph runtime function - let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap(); + let runtime_create_fn = Function::get("tvm.graph_runtime.create").unwrap(); let runtime_create_fn_ret = call_packed!( runtime_create_fn, &graph, @@ -85,14 +88,19 @@ fn main() { .get_function("set_input", false) .unwrap(); - call_packed!(set_input_fn, "data", &input).unwrap(); + let data_str = "data".to_string(); + call_packed!(set_input_fn, &data_str, &input).unwrap(); // get `run` function from runtime module let ref run_fn = graph_runtime_module.get_function("run", false).unwrap(); // execute the run function. Note that it has no argument call_packed!(run_fn,).unwrap(); // prepare to get the output let output_shape = &mut [1, 1000]; - let output = NDArray::empty(output_shape, TVMContext::cpu(0), TVMType::from("float32")); + let output = NDArray::empty( + output_shape, + TVMContext::cpu(0), + TVMType::from_str("float32").unwrap(), + ); // get the `get_output` function from runtime module let ref get_output_fn = graph_runtime_module .get_function("get_output", false) diff --git a/rust/frontend/src/bytearray.rs b/rust/frontend/src/bytearray.rs index 395f34c2428d..9274dba862da 100644 --- a/rust/frontend/src/bytearray.rs +++ b/rust/frontend/src/bytearray.rs @@ -3,9 +3,9 @@ //! //! For more detail, please see the example `resnet` in `examples` repository. -use std::os::raw::c_char; +use std::os::raw::{c_char, c_void}; -use crate::ts; +use tvm_common::{ffi, TVMArgValue}; /// A struct holding TVM byte-array. /// @@ -19,11 +19,11 @@ use crate::ts; /// ``` #[derive(Debug, Clone)] pub struct TVMByteArray { - pub(crate) inner: ts::TVMByteArray, + pub(crate) inner: ffi::TVMByteArray, } impl TVMByteArray { - pub(crate) fn new(barr: ts::TVMByteArray) -> TVMByteArray { + pub(crate) fn new(barr: ffi::TVMByteArray) -> TVMByteArray { TVMByteArray { inner: barr } } @@ -46,7 +46,7 @@ impl TVMByteArray { impl<'a> From<&'a Vec> for TVMByteArray { fn from(arg: &Vec) -> Self { - let barr = ts::TVMByteArray { + let barr = ffi::TVMByteArray { data: arg.as_ptr() as *const c_char, size: arg.len(), }; @@ -54,6 +54,18 @@ impl<'a> From<&'a Vec> for TVMByteArray { } } +impl<'a> From<&TVMByteArray> for TVMArgValue<'a> { + fn from(arr: &TVMByteArray) -> Self { + Self { + value: ffi::TVMValue { + v_handle: &arr.inner as *const ffi::TVMByteArray as *const c_void as *mut c_void, + }, + type_code: ffi::TVMTypeCode_kBytes as i64, + _lifetime: std::marker::PhantomData, + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/rust/frontend/src/context.rs b/rust/frontend/src/context.rs index 65e11d82e2d0..5d800a8b9644 100644 --- a/rust/frontend/src/context.rs +++ b/rust/frontend/src/context.rs @@ -18,12 +18,20 @@ //! ``` use std::{ + convert::TryInto, fmt::{self, Display, Formatter}, os::raw::c_void, ptr, }; -use crate::{function, ts, Result}; +use failure::Error; + +use tvm_common::{ + ffi::{self, TVMValue}, + TVMArgValue, +}; + +use crate::function; /// Device type can be from a supported device name. See the supported devices /// in [TVM](https://github.com/dmlc/tvm). @@ -45,35 +53,35 @@ impl Default for TVMDeviceType { } } -impl From for ts::DLDeviceType { +impl From for ffi::DLDeviceType { fn from(device_type: TVMDeviceType) -> Self { match device_type.0 { - 1 => ts::DLDeviceType_kDLCPU, - 2 => ts::DLDeviceType_kDLGPU, - 3 => ts::DLDeviceType_kDLCPUPinned, - 4 => ts::DLDeviceType_kDLOpenCL, - 7 => ts::DLDeviceType_kDLVulkan, - 8 => ts::DLDeviceType_kDLMetal, - 9 => ts::DLDeviceType_kDLVPI, - 10 => ts::DLDeviceType_kDLROCM, - 12 => ts::DLDeviceType_kDLExtDev, + 1 => ffi::DLDeviceType_kDLCPU, + 2 => ffi::DLDeviceType_kDLGPU, + 3 => ffi::DLDeviceType_kDLCPUPinned, + 4 => ffi::DLDeviceType_kDLOpenCL, + 7 => ffi::DLDeviceType_kDLVulkan, + 8 => ffi::DLDeviceType_kDLMetal, + 9 => ffi::DLDeviceType_kDLVPI, + 10 => ffi::DLDeviceType_kDLROCM, + 12 => ffi::DLDeviceType_kDLExtDev, _ => panic!("device type not found!"), } } } -impl From for TVMDeviceType { - fn from(device_type: ts::DLDeviceType) -> Self { +impl From for TVMDeviceType { + fn from(device_type: ffi::DLDeviceType) -> Self { match device_type { - ts::DLDeviceType_kDLCPU => TVMDeviceType(1), - ts::DLDeviceType_kDLGPU => TVMDeviceType(2), - ts::DLDeviceType_kDLCPUPinned => TVMDeviceType(3), - ts::DLDeviceType_kDLOpenCL => TVMDeviceType(4), - ts::DLDeviceType_kDLVulkan => TVMDeviceType(7), - ts::DLDeviceType_kDLMetal => TVMDeviceType(8), - ts::DLDeviceType_kDLVPI => TVMDeviceType(9), - ts::DLDeviceType_kDLROCM => TVMDeviceType(10), - ts::DLDeviceType_kDLExtDev => TVMDeviceType(12), + ffi::DLDeviceType_kDLCPU => TVMDeviceType(1), + ffi::DLDeviceType_kDLGPU => TVMDeviceType(2), + ffi::DLDeviceType_kDLCPUPinned => TVMDeviceType(3), + ffi::DLDeviceType_kDLOpenCL => TVMDeviceType(4), + ffi::DLDeviceType_kDLVulkan => TVMDeviceType(7), + ffi::DLDeviceType_kDLMetal => TVMDeviceType(8), + ffi::DLDeviceType_kDLVPI => TVMDeviceType(9), + ffi::DLDeviceType_kDLROCM => TVMDeviceType(10), + ffi::DLDeviceType_kDLExtDev => TVMDeviceType(12), _ => panic!("device type not found!"), } } @@ -117,6 +125,18 @@ impl<'a> From<&'a str> for TVMDeviceType { } } +impl<'a> From<&'a TVMDeviceType> for TVMArgValue<'a> { + fn from(dev_type: &'a TVMDeviceType) -> Self { + Self { + value: TVMValue { + v_int64: dev_type.0 as i64, + }, + type_code: ffi::DLDataTypeCode_kDLInt as i64, + _lifetime: std::marker::PhantomData, + } + } +} + /// Represents the underlying device context. Default is cpu. /// /// ## Examples @@ -138,15 +158,15 @@ pub struct TVMContext { /// Supported device types pub device_type: TVMDeviceType, /// Device id - pub device_id: usize, + pub device_id: i32, } impl TVMContext { /// Creates context from device type and id. - pub fn new(device_type: TVMDeviceType, device_id: usize) -> Self { + pub fn new(device_type: TVMDeviceType, device_id: i32) -> Self { TVMContext { - device_type: device_type, - device_id: device_id, + device_type, + device_id, } } } @@ -155,7 +175,7 @@ macro_rules! impl_ctxs { ($(($ctx:ident, $dldevt:expr));+) => { $( impl TVMContext { - pub fn $ctx(device_id: usize) -> Self { + pub fn $ctx(device_id: i32) -> Self { Self::new(TVMDeviceType($dldevt), device_id) } } @@ -185,20 +205,20 @@ impl<'a> From<&'a str> for TVMContext { impl TVMContext { /// Checks whether the context exists or not. pub fn exist(&self) -> bool { - let func = function::Function::get("_GetDeviceAttr", true /* is_global */) - .expect("API function always exists"); + let func = function::Function::get("_GetDeviceAttr").expect("API function always exists"); let dt = self.device_type.0 as usize; // `unwrap` is ok here because if there is any error, // if would occure inside `call_packed!` - let ret = call_packed!(func, &dt, &self.device_id, &0) + let ret: u64 = call_packed!(func, &dt, &self.device_id, &0) .unwrap() - .prim_value; + .try_into() + .unwrap(); ret != 0 } /// Synchronize the context stream. - pub fn sync(&self) -> Result<()> { - check_call!(ts::TVMSynchronize( + pub fn sync(&self) -> Result<(), Error> { + check_call!(ffi::TVMSynchronize( self.device_type.0 as i32, self.device_id as i32, ptr::null_mut() as *mut c_void @@ -212,16 +232,17 @@ macro_rules! impl_device_attrs { $( impl TVMContext { pub fn $attr_name(&self) -> usize { - let func = function::Function::get("_GetDeviceAttr", true /* is_global */) + let func = function::Function::get("_GetDeviceAttr") .expect("API function always exists"); let dt = self.device_type.0 as usize; // `unwrap` is ok here because if there is any error, // if would occur in function call. - let ret = function::Builder::from(func) - .args(&[dt, self.device_id, $attr_kind]) + function::Builder::from(func) + .args(&[dt, self.device_id as usize, $attr_kind]) .invoke() - .unwrap(); - ret.prim_value as usize + .unwrap() + .try_into() + .unwrap() } } )+ @@ -237,18 +258,18 @@ impl_device_attrs!((max_threads_per_block, 1); (multi_processor_count, 7); (max_thread_dimensions, 8)); -impl From for TVMContext { - fn from(ctx: ts::DLContext) -> Self { +impl From for TVMContext { + fn from(ctx: ffi::DLContext) -> Self { TVMContext { device_type: TVMDeviceType::from(ctx.device_type), - device_id: ctx.device_id as usize, + device_id: ctx.device_id, } } } -impl From for ts::DLContext { +impl From for ffi::DLContext { fn from(ctx: TVMContext) -> Self { - ts::DLContext { + ffi::DLContext { device_type: ctx.device_type.into(), device_id: ctx.device_id as i32, } diff --git a/rust/frontend/src/errors.rs b/rust/frontend/src/errors.rs index a10f83c4110d..96a70caf59c0 100644 --- a/rust/frontend/src/errors.rs +++ b/rust/frontend/src/errors.rs @@ -1,51 +1,26 @@ -//! This module implements TVM custom [`Error`], [`ErrorKind`] and [`Result`] types. +pub use failure::Error; -use std::{ffi, option}; +#[derive(Debug, Fail)] +#[fail(display = "Cannot convert from an empty array.")] +pub struct EmptyArrayError; -use crate::{common_errors, rust_ndarray}; - -error_chain! { - errors { - EmptyArray { - description("cannot convert from an empty array") - } - - NullHandle(name: String) { - description("null handle") - display("requested `{}` handle is null", name) - } - - FunctionNotFound { - description("function not found") - display("function was not set in `function::Builder`") - } - - TypeMismatch(expected: String, found: String) { - description("type mismatch!") - display("expected type `{}`, but found `{}`", expected, found) - } - - MissingShapeError { - description("ndarray `shape()` returns `None`") - display("called `Option::unwrap()` on a `None` value") - } - - AtMostOneReturn { - description("TVM functions accept at most one return value") - } +#[derive(Debug, Fail)] +#[fail(display = "Handle `{}` is null.", name)] +pub struct NullHandleError { + pub name: String, +} - } +#[derive(Debug, Fail)] +#[fail(display = "Function was not set in `function::Builder`")] +pub struct FunctionNotFoundError; - foreign_links { - ShapeError(rust_ndarray::ShapeError); - NulError(ffi::NulError); - IntoStringError(ffi::IntoStringError); - CommonError(common_errors::Error); - } +#[derive(Debug, Fail)] +#[fail(display = "Expected type `{}` but found `{}`", expected, actual)] +pub struct TypeMismatchError { + pub expected: String, + pub actual: String, } -impl From for Error { - fn from(_err: option::NoneError) -> Self { - ErrorKind::MissingShapeError.into() - } -} +#[derive(Debug, Fail)] +#[fail(display = "Missing NDArray shape.")] +pub struct MissingShapeError; diff --git a/rust/frontend/src/function.rs b/rust/frontend/src/function.rs index fa6bed141076..f0fbcbe67e25 100644 --- a/rust/frontend/src/function.rs +++ b/rust/frontend/src/function.rs @@ -15,14 +15,20 @@ use std::{ sync::Mutex, }; -use crate::{ts, ErrorKind, Module, Result, TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue}; +use failure::Error; + +use crate::{ + errors, + ffi::{self, TVMValue}, + Module, TVMArgValue, TVMRetValue, +}; lazy_static! { static ref GLOBAL_FUNCTIONS: Mutex>> = { let mut out_size = 0 as c_int; let name = ptr::null_mut() as *mut c_char; let mut out_array = name as *mut _; - check_call!(ts::TVMFuncListGlobalNames( + check_call!(ffi::TVMFuncListGlobalNames( &mut out_size as *mut _, &mut out_array )); @@ -37,17 +43,14 @@ lazy_static! { } /// Wrapper around TVM function handle which includes `is_global` -/// indicating whether the function is global or not, `is_released` -/// to hint dropping the function handle and `is_cloned` showing +/// indicating whether the function is global or not, and `is_cloned` showing /// not to drop a cloned function from Rust side. /// The value of these fields can be accessed through their respective methods. #[derive(Debug, Hash)] pub struct Function { - pub(crate) handle: ts::TVMFunctionHandle, + pub(crate) handle: ffi::TVMFunctionHandle, // whether the registered function is global or not. is_global: bool, - // whether the function has been dropped from frontend or not. - is_released: bool, // whether the function has been cloned from frontend or not. is_cloned: bool, } @@ -56,29 +59,30 @@ unsafe impl Send for Function {} unsafe impl Sync for Function {} impl Function { - pub(crate) fn new(handle: ts::TVMFunctionHandle, is_global: bool, is_released: bool) -> Self { + pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self { Function { handle: handle, - is_global: is_global, - is_released: is_released, + is_global: false, is_cloned: false, } } /// For a given function, it returns a function by name. - pub fn get>(name: S, is_global: bool) -> Option<&'static Function> { + pub fn get>(name: S) -> Option<&'static Function> { let mut globals = GLOBAL_FUNCTIONS.lock().unwrap(); globals.get_mut(name.as_ref()).and_then(|maybe_func| { if maybe_func.is_none() { let name = CString::new(name.as_ref()).unwrap(); - let mut handle = ptr::null_mut() as ts::TVMFunctionHandle; - check_call!(ts::TVMFuncGetGlobal( + let mut handle = ptr::null_mut() as ffi::TVMFunctionHandle; + check_call!(ffi::TVMFuncGetGlobal( name.as_ptr() as *const c_char, &mut handle as *mut _ )); - maybe_func.replace(Function::new( - handle, is_global, false, /* is_released */ - )); + maybe_func.replace(Function { + handle: handle, + is_global: true, + is_cloned: false, + }); } unsafe { std::mem::transmute::, Option<&'static Function>>( @@ -89,7 +93,7 @@ impl Function { } /// Returns the underlying TVM function handle. - pub fn handle(&self) -> ts::TVMFunctionHandle { + pub fn handle(&self) -> ffi::TVMFunctionHandle { self.handle } @@ -98,12 +102,6 @@ impl Function { self.is_global } - /// Returns `true` if the underlying TVM function has been released - /// from the frontend and `false` otherwise. - pub fn is_released(&self) -> bool { - self.is_released - } - /// Returns `true` if the underlying TVM function has been cloned /// from the frontend and `false` otherwise. pub fn is_cloned(&self) -> bool { @@ -113,24 +111,18 @@ impl Function { impl Clone for Function { fn clone(&self) -> Function { - if !self.is_released && !self.is_cloned { - Self { - handle: self.handle, - is_global: self.is_global, - is_released: self.is_released, - is_cloned: true, - } - } else { - Function::new(self.handle, self.is_global, self.is_released) + Self { + handle: self.handle, + is_global: self.is_global, + is_cloned: true, } } } impl Drop for Function { fn drop(&mut self) { - if !self.is_released && !self.is_global && !self.is_cloned { - check_call!(ts::TVMFuncFree(self.handle)); - self.is_released = true; + if !self.is_global && !self.is_cloned { + check_call!(ffi::TVMFuncFree(self.handle)); } } } @@ -138,17 +130,17 @@ impl Drop for Function { /// Function builder in order to create and call functions. /// /// *Note:* Currently TVM functions accept *at most* one return value. -#[derive(Debug, Clone, Default)] +#[derive(Default)] pub struct Builder<'a, 'm> { pub func: Option<&'m Function>, - pub arg_buf: Option]>>, + pub arg_buf: Vec>, pub ret_buf: Option, } impl<'a, 'm> Builder<'a, 'm> { pub fn new( func: Option<&'m Function>, - arg_buf: Option]>>, + arg_buf: Vec>, ret_buf: Option, ) -> Self { Self { @@ -158,123 +150,66 @@ impl<'a, 'm> Builder<'a, 'm> { } } - pub fn get_function(&mut self, name: &'m str, is_global: bool) -> &mut Self { - self.func = Function::get(name, is_global); + pub fn get_function(&mut self, name: &'m str) -> &mut Self { + self.func = Function::get(name); self } /// Pushes a [`TVMArgValue`] into the function argument buffer. - pub fn arg<'b, T: ?Sized>(&mut self, arg: &'b T) -> &mut Self + pub fn arg(&mut self, arg: &'a T) -> &mut Self where - TVMValue: From<&'b T>, - TVMTypeCode: From<&'b T>, + TVMArgValue<'a>: From<&'a T>, { - let tvm_arg = TVMArgValue::from(arg); - if self.arg_buf.is_none() { - self.arg_buf = Some(Box::new([tvm_arg])); - } else { - let new_arg_buf = self.arg_buf.take().map(|bbuf| { - let mut new_arg_buf = Vec::from(bbuf); - new_arg_buf.push(tvm_arg); - let new_len = new_arg_buf.len(); - new_arg_buf.truncate(new_len); - new_arg_buf.into_boxed_slice() - }); - self.arg_buf = new_arg_buf; - } + self.arg_buf.push(arg.into()); self } /// Pushes multiple [`TVMArgValue`]s into the function argument buffer. - pub fn args<'b, T: 'b + ?Sized, I>(&mut self, args: I) -> &mut Self + pub fn args(&mut self, args: I) -> &mut Self where - I: IntoIterator, - TVMValue: From<&'b T>, - TVMTypeCode: From<&'b T>, + I: IntoIterator, + TVMArgValue<'a>: From<&'a T>, { - for arg in args { + args.into_iter().for_each(|arg| { self.arg(&arg); - } + }); self } /// Sets an output for a function that requirs a mutable output to be provided. /// See the `basics` in tests for an example. - pub fn set_output<'b, T: 'b + ?Sized>(&mut self, arg: &'b mut T) -> Result<&mut Self> + pub fn set_output(&mut self, ret: T) -> &mut Self where - TVMValue: From<&'b T>, - TVMTypeCode: From<&'b T>, + TVMRetValue: From, { - if self.ret_buf.is_none() { - let tvm_ret = - unsafe { TVMRetValue::from_tvm_value(TVMValue::from(arg), TVMTypeCode::from(arg)) }; - self.ret_buf = Some(tvm_ret); - } else { - bail!(ErrorKind::AtMostOneReturn) - } - Ok(self) + self.ret_buf = Some(ret.into()); + self } /// Calls the function that created from `Builder`. - pub fn invoke(&mut self) -> Result { - self.clone()(()) - } -} - -impl<'a, 'm> FnOnce<((),)> for Builder<'a, 'm> { - type Output = Result; - extern "rust-call" fn call_once(self, _: ((),)) -> Self::Output { - if self.func.is_none() { - bail!("{}", ErrorKind::FunctionNotFound); - } - - let mut ret_val = unsafe { mem::uninitialized::() }; - let mut ret_type_code = 0 as c_int; - if self.arg_buf.is_some() { - let arg_buf = self.arg_buf?; - let mut num_args = arg_buf.len(); - let mut values = arg_buf - .iter() - .map(|tav| tav.value.inner) - .collect::>(); - let mut tcodes = arg_buf - .iter() - .map(|tav| tav.type_code as c_int) - .collect::>(); - - if self.ret_buf.is_some() { - num_args = num_args + 1; - let ret_buf = self.ret_buf?; - let (ret_val, ret_type_code) = TVMRetValue::into_tvm_value(ret_buf); - values.append(&mut vec![ret_val.inner]); - tcodes.append(&mut vec![ret_type_code as c_int]); - } - - values.truncate(num_args); - tcodes.truncate(num_args); - check_call!(ts::TVMFuncCall( - self.func?.handle, - values.as_mut_ptr(), - tcodes.as_mut_ptr(), - num_args as c_int, - &mut ret_val as *mut _, - &mut ret_type_code as *mut _ - )); - } else { - check_call!(ts::TVMFuncCall( - self.func?.handle, - ptr::null_mut(), - ptr::null_mut(), - 0 as c_int, - &mut ret_val as *mut _, - &mut ret_type_code as *mut _ - )); - } + pub fn invoke(&mut self) -> Result { + #![allow(unused_unsafe)] + ensure!(self.func.is_some(), errors::FunctionNotFoundError); + + let num_args = self.arg_buf.len(); + let (mut values, mut type_codes): (Vec, Vec) = self + .arg_buf + .iter() + .map(|tvm_arg| (tvm_arg.value, tvm_arg.type_code as ffi::TVMTypeCode)) + .unzip(); + + let mut ret_val = unsafe { std::mem::uninitialized::() }; + let mut ret_type_code = 0; + check_call!(ffi::TVMFuncCall( + self.func.ok_or(errors::FunctionNotFoundError)?.handle, + values.as_mut_ptr(), + type_codes.as_mut_ptr() as *mut i32, + num_args as c_int, + &mut ret_val as *mut _, + &mut ret_type_code as *mut _ + )); - let ret = unsafe { - TVMRetValue::from_tvm_value(TVMValue::new(ret_val), (ret_type_code as i64).into()) - }; - Ok(ret) + Ok(unsafe { TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64) }) } } @@ -282,46 +217,44 @@ impl<'a, 'm> FnOnce<((),)> for Builder<'a, 'm> { /// TVM functions. impl<'a, 'm> From<&'m Function> for Builder<'a, 'm> { fn from(func: &'m Function) -> Self { - Builder::new(Some(func), None, None) + Builder::new(Some(func), Vec::new(), None) } } /// Converts a mutable reference of a [`Module`] to [`Builder`]. impl<'a, 'm> From<&'m mut Module> for Builder<'a, 'm> { fn from(module: &'m mut Module) -> Self { - Builder::new(module.entry(), None, None) + Builder::new(module.entry(), Vec::new(), None) } } unsafe extern "C" fn tvm_callback( - args: *mut ts::TVMValue, + args: *mut ffi::TVMValue, type_codes: *mut c_int, num_args: c_int, - ret: ts::TVMRetValueHandle, + ret: ffi::TVMRetValueHandle, fhandle: *mut c_void, ) -> c_int { // turning off the incorrect linter complaints - #![allow(unused_assignments)] + #![allow(unused_assignments, unused_unsafe)] let len = num_args as usize; let args_list = slice::from_raw_parts_mut(args, len); let type_codes_list = slice::from_raw_parts_mut(type_codes, len); let mut local_args: Vec = Vec::new(); - let mut value = mem::uninitialized::(); + let mut value = mem::uninitialized::(); let mut tcode = mem::uninitialized::(); - let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result>(fhandle); + let rust_fn = + mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result>(fhandle); for i in 0..len { value = args_list[i]; tcode = type_codes_list[i]; - if tcode == ts::TVMTypeCode_kNodeHandle as c_int - || tcode == ts::TVMTypeCode_kFuncHandle as c_int - || tcode == ts::TVMTypeCode_kModuleHandle as c_int + if tcode == ffi::TVMTypeCode_kNodeHandle as c_int + || tcode == ffi::TVMTypeCode_kFuncHandle as c_int + || tcode == ffi::TVMTypeCode_kModuleHandle as c_int { - check_call!(ts::TVMCbArgToReturn(&mut value as *mut _, tcode)); + check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, tcode)); } - local_args.push(TVMArgValue::new( - TVMValue::new(value), - (tcode as i64).into(), - )); + local_args.push(TVMArgValue::new(value.into(), (tcode as i64).into())); } let rv = match rust_fn(local_args.as_slice()) { @@ -332,10 +265,9 @@ unsafe extern "C" fn tvm_callback( } }; - let (ret_val, ret_tcode) = TVMRetValue::into_tvm_value(rv); - let mut ret_val = ret_val.inner; + let (mut ret_val, ret_tcode) = rv.into_tvm_value(); let mut ret_type_code = ret_tcode as c_int; - check_call!(ts::TVMCFuncSetReturn( + check_call!(ffi::TVMCFuncSetReturn( ret, &mut ret_val as *mut _, &mut ret_type_code as *mut _, @@ -345,24 +277,25 @@ unsafe extern "C" fn tvm_callback( } unsafe extern "C" fn tvm_callback_finalizer(fhandle: *mut c_void) { - let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result>(fhandle); + let rust_fn = + mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result>(fhandle); mem::drop(rust_fn); } -fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result) -> Function { - let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle; - let resource_handle = f as *mut fn(&[TVMArgValue]) -> Result; - check_call!(ts::TVMFuncCreateFromCFunc( +fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result) -> Function { + let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; + let resource_handle = f as *mut fn(&[TVMArgValue]) -> Result; + check_call!(ffi::TVMFuncCreateFromCFunc( Some(tvm_callback), resource_handle as *mut c_void, Some(tvm_callback_finalizer), &mut fhandle as *mut _ )); - Function::new(fhandle, false, false) + Function::new(fhandle) } /// Registers a Rust function with signature -/// `fn(&[TVMArgValue]) -> Result` +/// `fn(&[TVMArgValue]) -> Result` /// as a **global TVM packed function** from frontend to TVM backend. /// /// Use [`register_global_func`] if overriding an existing global TVM function @@ -373,7 +306,7 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result) -> Function /// ``` /// use std::convert::TryInto; /// -/// fn sum(args: &[TVMArgValue]) -> Result { +/// fn sum(args: &[TVMArgValue]) -> Result { /// let mut ret = 0i64; /// for arg in args.iter() { /// let arg: i64 = arg.try_into()?; @@ -391,18 +324,17 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result) -> Function /// assert_eq!(ret, 60); /// ``` pub fn register>( - f: fn(&[TVMArgValue]) -> Result, + f: fn(&[TVMArgValue]) -> Result, name: S, override_: bool, -) -> Result<()> { +) -> Result<(), Error> { let func = convert_to_tvm_func(f); let name = CString::new(name.as_ref())?; - check_call!(ts::TVMFuncRegisterGlobal( - name.as_ref().as_ptr() as *const c_char, + check_call!(ffi::TVMFuncRegisterGlobal( + name.into_raw(), func.handle(), override_ as c_int )); - mem::forget(name); Ok(()) } @@ -416,7 +348,7 @@ pub fn register>( /// use std::convert::TryInto; /// /// register_global_func! { -/// fn sum(args: &[TVMArgValue]) -> Result { +/// fn sum(args: &[TVMArgValue]) -> Result { /// let mut ret = 0f64; /// for arg in args.iter() { /// let arg: f64 = arg.try_into()?; @@ -437,12 +369,12 @@ pub fn register>( macro_rules! register_global_func { { $(#[$m:meta])* - fn $fn_name:ident($args:ident : &[TVMArgValue]) -> Result { + fn $fn_name:ident($args:ident : &[TVMArgValue]) -> Result { $($code:tt)* } } => {{ $(#[$m])* - fn $fn_name($args: &[TVMArgValue]) -> Result { + fn $fn_name($args: &[TVMArgValue]) -> Result { $($code)* } @@ -496,17 +428,17 @@ mod tests { #[test] fn get_fn() { - assert!(Function::get(CANARY, true).is_some()); - assert!(Function::get("does not exists!", false).is_none()); + assert!(Function::get(CANARY).is_some()); + assert!(Function::get("does not exists!").is_none()); } #[test] fn provide_args() { + let str_arg = CString::new("test").unwrap(); let mut func = Builder::default(); - func.get_function("tvm.graph_runtime.remote_create", true) + func.get_function("tvm.graph_runtime.remote_create") .args(&[10, 20]) - .arg(&"test".to_owned()); - assert!(func.arg_buf.is_some()); - assert_eq!(func.arg_buf.take().map(|bv| Vec::from(bv).len()), Some(3)); + .arg(&str_arg); + assert_eq!(func.arg_buf.len(), 3); } } diff --git a/rust/frontend/src/lib.rs b/rust/frontend/src/lib.rs index 6e15e4f8d046..a773b2735d9c 100644 --- a/rust/frontend/src/lib.rs +++ b/rust/frontend/src/lib.rs @@ -11,32 +11,36 @@ //! //! Checkout the `examples` repository for more details. -#![crate_name = "tvm_frontend"] -#![recursion_limit = "1024"] -#![allow(non_camel_case_types, unused_unsafe)] -#![feature( - try_from, - try_trait, - fn_traits, - unboxed_closures, - box_syntax, - option_replace -)] +#![feature(box_syntax)] #[macro_use] -extern crate error_chain; -extern crate tvm_common as common; +extern crate failure; #[macro_use] extern crate lazy_static; extern crate ndarray as rust_ndarray; extern crate num_traits; +extern crate tvm_common; use std::{ ffi::{CStr, CString}, str, }; -use crate::common::ffi::ts; +use failure::Error; + +pub use crate::{ + bytearray::TVMByteArray, + context::{TVMContext, TVMDeviceType}, + errors::*, + function::Function, + module::Module, + ndarray::NDArray, + tvm_common::{ + errors as common_errors, + ffi::{self, TVMType}, + packed_func::{TVMArgValue, TVMRetValue}, + }, +}; // Macro to check the return call to TVM runtime shared library. macro_rules! check_call { @@ -50,7 +54,7 @@ macro_rules! check_call { /// Gets the last error message. pub fn get_last_error() -> &'static str { unsafe { - match CStr::from_ptr(ts::TVMGetLastError()).to_str() { + match CStr::from_ptr(ffi::TVMGetLastError()).to_str() { Ok(s) => s, Err(_) => "Invalid UTF-8 message", } @@ -60,7 +64,7 @@ pub fn get_last_error() -> &'static str { pub(crate) fn set_last_error(err: &Error) { let c_string = CString::new(err.to_string()).unwrap(); unsafe { - ts::TVMAPISetLastError(c_string.as_ptr()); + ffi::TVMAPISetLastError(c_string.as_ptr()); } } @@ -71,27 +75,11 @@ pub mod context; pub mod errors; pub mod module; pub mod ndarray; -pub mod ty; pub mod value; -pub use crate::{ - bytearray::TVMByteArray, - common::{ - errors as common_errors, - ty::TVMTypeCode, - value::{TVMArgValue, TVMRetValue, TVMValue}, - }, - context::{TVMContext, TVMDeviceType}, - errors::*, - function::Function, - module::Module, - ndarray::NDArray, - ty::TVMType, -}; - /// Outputs the current TVM version. pub fn version() -> &'static str { - match str::from_utf8(ts::TVM_VERSION) { + match str::from_utf8(ffi::TVM_VERSION) { Ok(s) => s, Err(_) => "Invalid UTF-8 string", } @@ -108,8 +96,8 @@ mod tests { #[test] fn set_error() { - let err = ErrorKind::EmptyArray; + let err = errors::EmptyArrayError; set_last_error(&err.into()); - assert_eq!(get_last_error().trim(), ErrorKind::EmptyArray.to_string()); + assert_eq!(get_last_error().trim(), errors::EmptyArrayError.to_string()); } } diff --git a/rust/frontend/src/module.rs b/rust/frontend/src/module.rs index c12d9d48cf13..9c27387520dc 100644 --- a/rust/frontend/src/module.rs +++ b/rust/frontend/src/module.rs @@ -8,30 +8,27 @@ use std::{ ptr, }; -use crate::ts; +use failure::Error; +use tvm_common::ffi; -use crate::{function::Function, ErrorKind, Result}; +use crate::{errors, function::Function}; const ENTRY_FUNC: &'static str = "__tvm_main__"; /// Wrapper around TVM module handle which contains an entry function. /// The entry function can be applied to an imported module through [`entry_func`]. -/// Also [`is_released`] shows whether the module is dropped or not. /// /// [`entry_func`]:struct.Module.html#method.entry_func -/// [`is_released`]:struct.Module.html#method.is_released #[derive(Debug, Clone)] pub struct Module { - pub(crate) handle: ts::TVMModuleHandle, - is_released: bool, + pub(crate) handle: ffi::TVMModuleHandle, entry_func: Option, } impl Module { - pub(crate) fn new(handle: ts::TVMModuleHandle, is_released: bool) -> Self { + pub(crate) fn new(handle: ffi::TVMModuleHandle) -> Self { Self { handle, - is_released, entry_func: None, } } @@ -44,62 +41,67 @@ impl Module { } /// Gets a function by name from a registered module. - pub fn get_function(&self, name: &str, query_import: bool) -> Result { + pub fn get_function(&self, name: &str, query_import: bool) -> Result { let name = CString::new(name)?; - let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle; - check_call!(ts::TVMModGetFunction( + let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; + check_call!(ffi::TVMModGetFunction( self.handle, name.as_ptr() as *const c_char, query_import as c_int, &mut fhandle as *mut _ )); - if fhandle.is_null() { - bail!(ErrorKind::NullHandle(format!("{}", name.into_string()?))) - } else { - Ok(Function::new(fhandle, false, false)) - } + ensure!( + !fhandle.is_null(), + errors::NullHandleError { + name: format!("{}", name.into_string()?) + } + ); + Ok(Function::new(fhandle)) } /// Imports a dependent module such as `.ptx` for gpu. pub fn import_module(&self, dependent_module: Module) { - check_call!(ts::TVMModImport(self.handle, dependent_module.handle)) + check_call!(ffi::TVMModImport(self.handle, dependent_module.handle)) } /// Loads a module shared library from path. - pub fn load>(path: &P) -> Result { - let ext = path.as_ref().extension()?.to_str()?; - let func = Function::get("module._LoadFromFile", true /* is_global */) - .expect("API function always exists"); - let ret: Module = call_packed!(func, path.as_ref().to_str()?, ext)?.try_into()?; + pub fn load>(path: &P) -> Result { + let ext = CString::new( + path.as_ref() + .extension() + .unwrap_or(std::ffi::OsStr::new("")) + .to_str() + .ok_or_else(|| { + format_err!("Bad module load path: `{}`.", path.as_ref().display()) + })?, + )?; + let func = Function::get("module._LoadFromFile").expect("API function always exists"); + let cpath = + CString::new(path.as_ref().to_str().ok_or_else(|| { + format_err!("Bad module load path: `{}`.", path.as_ref().display()) + })?)?; + let ret: Module = call_packed!(func, &cpath, &ext)?.try_into()?; Ok(ret) } /// Checks if a target device is enabled for a module. pub fn enabled(&self, target: &str) -> bool { - let func = Function::get("module._Enabled", true /* is_global */) - .expect("API function always exists"); + let func = Function::get("module._Enabled").expect("API function always exists"); // `unwrap` is safe here because if there is any error during the // function call, it would occur in `call_packed!`. - let ret: i64 = call_packed!(func, target).unwrap().try_into().unwrap(); + let tgt = CString::new(target).unwrap(); + let ret: i64 = call_packed!(func, &tgt).unwrap().try_into().unwrap(); ret != 0 } /// Returns the underlying module handle. - pub fn handle(&self) -> ts::TVMModuleHandle { + pub fn handle(&self) -> ffi::TVMModuleHandle { self.handle } - - /// Returns true if the underlying module has been dropped and false otherwise. - pub fn is_released(&self) -> bool { - self.is_released - } } impl Drop for Module { fn drop(&mut self) { - if !self.is_released { - check_call!(ts::TVMModFree(self.handle)); - self.is_released = true; - } + check_call!(ffi::TVMModFree(self.handle)); } } diff --git a/rust/frontend/src/ndarray.rs b/rust/frontend/src/ndarray.rs index 44dfcca3b320..1939c92c0f0b 100644 --- a/rust/frontend/src/ndarray.rs +++ b/rust/frontend/src/ndarray.rs @@ -23,34 +23,34 @@ //! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer //! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx -use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice}; +use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; -use crate::rust_ndarray::{Array, ArrayD}; +use failure::Error; use num_traits::Num; +use rust_ndarray::{Array, ArrayD}; +use tvm_common::{ffi, TVMType}; -use crate::ts; - -use crate::{Error, ErrorKind, Result, TVMByteArray, TVMContext, TVMType}; +use crate::{errors, TVMByteArray, TVMContext}; /// See the [`module-level documentation`](../ndarray/index.html) for more details. /// /// Wrapper around TVM array handle. #[derive(Debug)] pub struct NDArray { - pub(crate) handle: ts::TVMArrayHandle, + pub(crate) handle: ffi::TVMArrayHandle, is_view: bool, } impl NDArray { - pub(crate) fn new(handle: ts::TVMArrayHandle, is_view: bool) -> Self { + pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self { NDArray { handle: handle, - is_view: is_view, + is_view: true, } } /// Returns the underlying array handle. - pub fn handle(&self) -> ts::TVMArrayHandle { + pub fn handle(&self) -> ffi::TVMArrayHandle { self.handle } @@ -99,12 +99,13 @@ impl NDArray { } /// Shows whether the underlying ndarray is contiguous in memory or not. - pub fn is_contiguous(&self) -> Result { + pub fn is_contiguous(&self) -> Result { Ok(match self.strides() { None => true, Some(strides) => { - // MissingShapeError in case shape is not determined - self.shape()? + // errors::MissingShapeError in case shape is not determined + self.shape() + .ok_or(errors::MissingShapeError)? .iter() .zip(strides) .rfold( @@ -138,14 +139,16 @@ impl NDArray { /// assert_eq!(ndarray.shape(), Some(shape)); /// assert_eq!(ndarray.to_vec::().unwrap(), data); /// ``` - pub fn to_vec(&self) -> Result> { - if self.shape().is_none() { - bail!("{}", ErrorKind::EmptyArray); - } - let earr = NDArray::empty(self.shape()?, TVMContext::cpu(0), self.dtype()); + pub fn to_vec(&self) -> Result, Error> { + ensure!(self.shape().is_some(), errors::EmptyArrayError); + let earr = NDArray::empty( + self.shape().ok_or(errors::MissingShapeError)?, + TVMContext::cpu(0), + self.dtype(), + ); let target = self.copy_to_ndarray(earr)?; let arr = unsafe { *(target.handle) }; - let sz = self.size()? as usize; + let sz = self.size().ok_or(errors::MissingShapeError)?; let mut v: Vec = Vec::with_capacity(sz * mem::size_of::()); unsafe { v.as_mut_ptr() @@ -156,7 +159,7 @@ impl NDArray { } /// Converts the NDArray to [`TVMByteArray`]. - pub fn to_bytearray(&self) -> Result { + pub fn to_bytearray(&self) -> Result { let v = self.to_vec::()?; Ok(TVMByteArray::from(&v)) } @@ -176,7 +179,7 @@ impl NDArray { /// *Note*: if something goes wrong during the copy, it will panic /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`. pub fn copy_from_buffer(&mut self, data: &mut [T]) { - check_call!(ts::TVMArrayCopyFromBytes( + check_call!(ffi::TVMArrayCopyFromBytes( self.handle, data.as_ptr() as *mut _, data.len() * mem::size_of::() @@ -184,27 +187,31 @@ impl NDArray { } /// Copies the NDArray to another target NDArray. - pub fn copy_to_ndarray(&self, target: NDArray) -> Result { + pub fn copy_to_ndarray(&self, target: NDArray) -> Result { if self.dtype() != target.dtype() { bail!( "{}", - ErrorKind::TypeMismatch( - format!("{}", self.dtype().to_string()), - format!("{}", target.dtype().to_string()), - ) + errors::TypeMismatchError { + expected: format!("{}", self.dtype().to_string()), + actual: format!("{}", target.dtype().to_string()), + } ); } - check_call!(ts::TVMArrayCopyFromTo( + check_call!(ffi::TVMArrayCopyFromTo( self.handle, target.handle, - ptr::null_mut() as ts::TVMStreamHandle + ptr::null_mut() as ffi::TVMStreamHandle )); Ok(target) } /// Copies the NDArray to a target context. - pub fn copy_to_ctx(&self, target: &TVMContext) -> Result { - let tmp = NDArray::empty(self.shape()?, target.clone(), self.dtype()); + pub fn copy_to_ctx(&self, target: &TVMContext) -> Result { + let tmp = NDArray::empty( + self.shape().ok_or(errors::MissingShapeError)?, + target.clone(), + self.dtype(), + ); let copy = self.copy_to_ndarray(tmp)?; Ok(copy) } @@ -214,28 +221,34 @@ impl NDArray { rnd: &ArrayD, ctx: TVMContext, dtype: TVMType, - ) -> Result { + ) -> Result { let mut shape = rnd.shape().to_vec(); let mut nd = NDArray::empty(&mut shape, ctx, dtype); let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T)); - nd.copy_from_buffer(buf.as_slice_mut()?); + nd.copy_from_buffer( + buf.as_slice_mut() + .expect("Array from iter must be contiguous."), + ); Ok(nd) } /// Allocates and creates an empty NDArray given the shape, context and dtype. pub fn empty(shape: &[usize], ctx: TVMContext, dtype: TVMType) -> NDArray { - let mut handle = ptr::null_mut() as ts::TVMArrayHandle; - check_call!(ts::TVMArrayAlloc( + let mut handle = ptr::null_mut() as ffi::TVMArrayHandle; + check_call!(ffi::TVMArrayAlloc( shape.as_ptr() as *const i64, shape.len() as c_int, - dtype.inner.code as c_int, - dtype.inner.bits as c_int, - dtype.inner.lanes as c_int, + dtype.code as c_int, + dtype.bits as c_int, + dtype.lanes as c_int, ctx.device_type.0 as c_int, ctx.device_id as c_int, &mut handle as *mut _, )); - NDArray::new(handle, false) + NDArray { + handle, + is_view: false, + } } } @@ -243,23 +256,25 @@ macro_rules! impl_from_ndarray_rustndarray { ($type:ty, $type_name:tt) => { impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> { type Error = Error; - fn try_from(nd: &NDArray) -> Result> { - if nd.shape().is_none() { - bail!("{}", ErrorKind::EmptyArray); - } - assert_eq!(nd.dtype(), TVMType::from($type_name), "Type mismatch"); - Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?) + fn try_from(nd: &NDArray) -> Result, Self::Error> { + ensure!(nd.shape().is_some(), errors::MissingShapeError); + assert_eq!(nd.dtype(), TVMType::from_str($type_name)?, "Type mismatch"); + Ok(Array::from_shape_vec( + &*nd.shape().ok_or(errors::MissingShapeError)?, + nd.to_vec::<$type>()?, + )?) } } impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> { type Error = Error; - fn try_from(nd: &mut NDArray) -> Result> { - if nd.shape().is_none() { - bail!("{}", ErrorKind::EmptyArray); - } - assert_eq!(nd.dtype(), TVMType::from($type_name), "Type mismatch"); - Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?) + fn try_from(nd: &mut NDArray) -> Result, Self::Error> { + ensure!(nd.shape().is_some(), errors::MissingShapeError); + assert_eq!(nd.dtype(), TVMType::from_str($type_name)?, "Type mismatch"); + Ok(Array::from_shape_vec( + &*nd.shape().ok_or(errors::MissingShapeError)?, + nd.to_vec::<$type>()?, + )?) } } }; @@ -272,7 +287,7 @@ impl_from_ndarray_rustndarray!(f32, "float"); impl Drop for NDArray { fn drop(&mut self) { if !self.is_view { - check_call!(ts::TVMArrayFree(self.handle)); + check_call!(ffi::TVMArrayFree(self.handle)); } } } @@ -306,7 +321,7 @@ mod tests { fn basics() { let shape = &mut [1, 2, 3]; let ctx = TVMContext::cpu(0); - let ndarray = NDArray::empty(shape, ctx, TVMType::from("int32")); + let ndarray = NDArray::empty(shape, ctx, TVMType::from_str("int32").unwrap()); assert_eq!(ndarray.shape().unwrap(), shape); assert_eq!( ndarray.size().unwrap(), @@ -322,7 +337,7 @@ mod tests { let shape = &mut [4]; let mut data = vec![1i32, 2, 3, 4]; let ctx = TVMContext::cpu(0); - let mut ndarray = NDArray::empty(shape, ctx, TVMType::from("int32")); + let mut ndarray = NDArray::empty(shape, ctx, TVMType::from_str("int32").unwrap()); assert!(ndarray.to_vec::().is_ok()); ndarray.copy_from_buffer(&mut data); assert_eq!(ndarray.shape().unwrap(), shape); @@ -331,7 +346,11 @@ mod tests { assert!(ndarray.is_contiguous().is_ok()); assert_eq!(ndarray.byte_offset(), 0); let mut shape = vec![4]; - let e = NDArray::empty(&mut shape, TVMContext::cpu(0), TVMType::from("int32")); + let e = NDArray::empty( + &mut shape, + TVMContext::cpu(0), + TVMType::from_str("int32").unwrap(), + ); let nd = ndarray.copy_to_ndarray(e); assert!(nd.is_ok()); assert_eq!(nd.unwrap().to_vec::().unwrap(), data); @@ -343,9 +362,13 @@ mod tests { let mut shape = vec![4]; let mut data = vec![1f32, 2., 3., 4.]; let ctx = TVMContext::cpu(0); - let mut nd_float = NDArray::empty(&mut shape, ctx.clone(), TVMType::from("float32")); + let mut nd_float = NDArray::empty( + &mut shape, + ctx.clone(), + TVMType::from_str("float32").unwrap(), + ); nd_float.copy_from_buffer(&mut data); - let empty_int = NDArray::empty(&mut shape, ctx, TVMType::from("int32")); + let empty_int = NDArray::empty(&mut shape, ctx, TVMType::from_str("int32").unwrap()); nd_float.copy_to_ndarray(empty_int).unwrap(); } @@ -354,8 +377,12 @@ mod tests { let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) .unwrap() .into_dyn(); - let nd = - NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), TVMType::from("float32")).unwrap(); + let nd = NDArray::from_rust_ndarray( + &a, + TVMContext::cpu(0), + TVMType::from_str("float32").unwrap(), + ) + .unwrap(); assert_eq!(nd.shape().unwrap(), &mut [2, 2]); let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); assert!(rnd.all_close(&a, 1e-8f32)); diff --git a/rust/frontend/src/ty.rs b/rust/frontend/src/ty.rs deleted file mode 100644 index 7e912a517e1d..000000000000 --- a/rust/frontend/src/ty.rs +++ /dev/null @@ -1,150 +0,0 @@ -//! This module implements the required conversions from Rust types to TVM types. -//! -//! In TVM frontend only conversions from Rust's 32-bits (POD) numeric types (i32, u32, f32) -//! and 64-bits pointers are supported. - -use std::{ - fmt::{self, Display, Formatter}, - ops::{Deref, DerefMut}, -}; - -use crate::ts; - -use crate::{Function, Module, NDArray, TVMByteArray, TVMContext, TVMDeviceType, TVMTypeCode}; - -macro_rules! impl_prim_type { - ($type:ty, $variant:ident) => { - impl From<$type> for TVMTypeCode { - fn from(_arg: $type) -> Self { - TVMTypeCode::$variant - } - } - - impl<'a> From<&'a $type> for TVMTypeCode { - fn from(_arg: &$type) -> Self { - TVMTypeCode::$variant - } - } - - impl<'a> From<&'a mut $type> for TVMTypeCode { - fn from(_arg: &mut $type) -> Self { - TVMTypeCode::$variant - } - } - }; -} - -impl_prim_type!(TVMDeviceType, kDLInt); -impl_prim_type!(TVMContext, kTVMContext); -impl_prim_type!(TVMType, kTVMType); -impl_prim_type!(Function, kFuncHandle); -impl_prim_type!(Module, kModuleHandle); -impl_prim_type!(NDArray, kArrayHandle); -impl_prim_type!(TVMByteArray, kBytes); - -/// See the [module-level documentation](../ty/index.html) for more details. -/// -/// Wrapper around underlying TVMType -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub struct TVMType { - // inner fields are (code: u8, bits: u8, lanes: u16) - pub inner: ts::TVMType, -} - -impl TVMType { - pub(crate) fn new(type_code: u8, bits: u8, lanes: u16) -> Self { - TVMType { - inner: ts::TVMType { - code: type_code, - bits: bits, - lanes: lanes, - }, - } - } -} - -/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}` -/// such as "int32", "float32" or with lane "float32x1". -impl<'a> From<&'a str> for TVMType { - fn from(type_str: &'a str) -> Self { - if type_str == "bool" { - return TVMType::new(1, 1, 1); - } - - let mut type_lanes = type_str.split("x"); - let typ = type_lanes.next().expect("Missing dtype"); - let lanes = type_lanes - .next() - .map(|l| u16::from_str_radix(l, 10).expect(&format!("Bad dtype lanes: {}", l))) - .unwrap_or(1); - let (type_name, bits) = match typ.find(char::is_numeric) { - Some(idx) => { - let (name, bits_str) = typ.split_at(idx); - ( - name, - u8::from_str_radix(bits_str, 10) - .expect(&format!("Bad dtype bits: {}", bits_str)), - ) - } - None => (typ, 32), - }; - - let type_code = match type_name { - "int" => 0, - "uint" => 1, - "float" => 2, - "handle" => 3, - _ => unimplemented!(), - }; - - TVMType::new(type_code, bits, lanes) - } -} - -impl Display for TVMType { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - let ts::TVMType { code, bits, lanes } = self.inner; - if bits == 1 && lanes == 1 { - return write!(f, "bool"); - } - let mut tcode_str = match code { - 0 => "int", - 1 => "uint", - 2 => "float", - 4 => "handle", - _ => "Unknown", - } - .to_string(); - - tcode_str += &bits.to_string(); - if lanes > 1 { - tcode_str += &format!("x{}", lanes.to_string()); - } - f.write_str(&tcode_str) - } -} - -impl From for ts::DLDataType { - fn from(dtype: TVMType) -> Self { - dtype.inner - } -} - -impl From for TVMType { - fn from(dtype: ts::DLDataType) -> Self { - Self::new(dtype.code, dtype.bits, dtype.lanes) - } -} - -impl Deref for TVMType { - type Target = ts::TVMType; - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -impl DerefMut for TVMType { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.inner - } -} diff --git a/rust/frontend/src/value.rs b/rust/frontend/src/value.rs index 9fad7de4984a..eb62f10cabec 100644 --- a/rust/frontend/src/value.rs +++ b/rust/frontend/src/value.rs @@ -2,139 +2,87 @@ //! and their conversions needed for the types used in frontend crate. //! `TVMRetValue` is the owned version of `TVMPODValue`. -use std::{convert::TryFrom, mem, os::raw::c_void}; +use std::{convert::TryFrom, os::raw::c_void}; + +use failure::Error; +use tvm_common::{ + ensure_type, + ffi::{self, TVMValue}, +}; use crate::{ - common_errors::*, ts, Function, Module, NDArray, TVMArgValue, TVMByteArray, TVMContext, - TVMDeviceType, TVMRetValue, TVMType, TVMTypeCode, TVMValue, + common_errors::*, context::TVMContext, Function, Module, NDArray, TVMArgValue, TVMByteArray, + TVMRetValue, }; macro_rules! impl_tvm_val_from_handle { - ($($ty:ty),+) => { - $( - impl<'a> From<&'a $ty> for TVMValue { - fn from(arg: &$ty) -> Self { - let inner = ts::TVMValue { + ($ty:ident, $type_code:expr, $handle:ty) => { + impl<'a> From<&'a $ty> for TVMArgValue<'a> { + fn from(arg: &$ty) -> Self { + TVMArgValue { + value: TVMValue { v_handle: arg.handle as *mut _ as *mut c_void, - }; - Self::new(inner) + }, + type_code: $type_code as i64, + _lifetime: std::marker::PhantomData, } } - )+ - } -} - -impl_tvm_val_from_handle!(Module, Function, NDArray); - -impl<'a> From<&'a TVMType> for TVMValue { - fn from(ty: &TVMType) -> Self { - let inner = ts::TVMValue { v_type: ty.inner }; - Self::new(inner) - } -} - -impl<'a> From<&'a TVMContext> for TVMValue { - fn from(ctx: &TVMContext) -> Self { - let inner = ts::TVMValue { - v_ctx: ctx.clone().into(), - }; - Self::new(inner) - } -} - -impl<'a> From<&'a TVMDeviceType> for TVMValue { - fn from(dev: &TVMDeviceType) -> Self { - let inner = ts::TVMValue { - v_int64: dev.0 as i64, - }; - Self::new(inner) - } -} - -impl<'a> From<&'a TVMByteArray> for TVMValue { - fn from(barr: &TVMByteArray) -> Self { - let inner = ts::TVMValue { - v_handle: &barr.inner as *const ts::TVMByteArray as *mut c_void, - }; - Self::new(inner) - } -} + } -impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for NDArray { - type Error = Error; - fn try_from(arg: &TVMArgValue<'a>) -> Result { - if arg.type_code == TVMTypeCode::kArrayHandle { - let handle = unsafe { arg.value.inner.v_handle }; - let arr_handle = unsafe { mem::transmute::<*mut c_void, ts::TVMArrayHandle>(handle) }; - Ok(Self::new(arr_handle, true)) - } else { - bail!(ErrorKind::TryFromTVMArgValueError( - stringify!(NDArray).to_string(), - arg.type_code.to_string() - )) + impl<'a> From<&'a mut $ty> for TVMArgValue<'a> { + fn from(arg: &mut $ty) -> Self { + TVMArgValue { + value: TVMValue { + v_handle: arg.handle as *mut _ as *mut c_void, + }, + type_code: $type_code as i64, + _lifetime: std::marker::PhantomData, + } + } } - } -} -impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for Module { - type Error = Error; - fn try_from(arg: &TVMArgValue<'a>) -> Result { - if arg.type_code == TVMTypeCode::kModuleHandle { - let handle = unsafe { arg.value.inner.v_handle }; - Ok(Self::new(handle, false)) - } else { - bail!(ErrorKind::TryFromTVMArgValueError( - stringify!(Module).to_string(), - arg.type_code.to_string() - )) + impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for $ty { + type Error = Error; + fn try_from(arg: &TVMArgValue<'v>) -> Result<$ty, Self::Error> { + ensure_type!(arg, $type_code); + Ok($ty::new(unsafe { arg.value.v_handle as $handle })) + } } - } -} -impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMByteArray { - type Error = Error; - fn try_from(arg: &TVMArgValue<'a>) -> Result { - if arg.type_code == TVMTypeCode::kBytes { - unsafe { - let barr_ptr = - mem::transmute::<*mut c_void, *mut ts::TVMByteArray>(arg.value.inner.v_handle); - Ok(Self::new(*barr_ptr)) + impl From<$ty> for TVMRetValue { + fn from(val: $ty) -> TVMRetValue { + TVMRetValue { + value: TVMValue { + v_handle: val.handle() as *mut c_void, + }, + box_value: box val, + type_code: $type_code as i64, + } } - } else { - bail!(ErrorKind::TryFromTVMArgValueError( - stringify!(TVMByteArray).to_string(), - arg.type_code.to_string() - )) } - } -} -impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMType { - type Error = Error; - fn try_from(arg: &TVMArgValue<'a>) -> Result { - if arg.type_code == TVMTypeCode::kTVMType { - let ty = unsafe { arg.value.inner.v_type }; - Ok(TVMType::from(ty)) - } else { - bail!(ErrorKind::TryFromTVMArgValueError( - stringify!(TVMType).to_string(), - arg.type_code.to_string() - )) + impl TryFrom for $ty { + type Error = Error; + fn try_from(ret: TVMRetValue) -> Result<$ty, Self::Error> { + ensure_type!(ret, $type_code); + Ok($ty::new(unsafe { ret.value.v_handle as $handle })) + } } - } + }; } -impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMContext { - type Error = Error; - fn try_from(arg: &TVMArgValue<'a>) -> Result { - if arg.type_code == TVMTypeCode::kTVMContext { - let ty = unsafe { arg.value.inner.v_ctx }; - Ok(TVMContext::from(ty)) - } else { - bail!(ErrorKind::TryFromTVMArgValueError( - stringify!(TVMContext).to_string(), - arg.type_code.to_string() - )) +impl_tvm_val_from_handle!( + Function, + ffi::TVMTypeCode_kFuncHandle, + ffi::TVMFunctionHandle +); +impl_tvm_val_from_handle!(Module, ffi::TVMTypeCode_kModuleHandle, ffi::TVMModuleHandle); +impl_tvm_val_from_handle!(NDArray, ffi::TVMTypeCode_kArrayHandle, ffi::TVMArrayHandle); + +impl<'a> From<&'a TVMByteArray> for TVMValue { + fn from(barr: &TVMByteArray) -> Self { + TVMValue { + v_handle: &barr.inner as *const ffi::TVMByteArray as *mut c_void, } } } @@ -144,78 +92,43 @@ macro_rules! impl_boxed_ret_value { impl From<$type> for TVMRetValue { fn from(val: $type) -> Self { TVMRetValue { - prim_value: 0, + value: TVMValue { v_int64: 0 }, box_value: box val, - type_code: $code, + type_code: $code as i64, } } } impl TryFrom for $type { type Error = Error; - fn try_from(ret: TVMRetValue) -> Result<$type> { + fn try_from(ret: TVMRetValue) -> Result<$type, Self::Error> { if let Ok(val) = ret.box_value.downcast::<$type>() { Ok(*val) } else { - bail!(ErrorKind::TryFromTVMRetValueError( - stringify!($type).to_string(), - ret.type_code.to_string() - )) + bail!(ValueDowncastError::new($code as i64, ret.type_code as i64)) } } } }; } -impl_boxed_ret_value!(TVMType, TVMTypeCode::kTVMType); -impl_boxed_ret_value!(TVMContext, TVMTypeCode::kTVMContext); -impl_boxed_ret_value!(TVMByteArray, TVMTypeCode::kBytes); - -impl TryFrom for Module { - type Error = Error; - fn try_from(ret: TVMRetValue) -> Result { - if let Ok(handle) = ret.box_value.downcast::() { - Ok(Module::new(*handle, false)) - } else { - bail!(ErrorKind::TryFromTVMRetValueError( - stringify!(TVMTypeCode::kModuleHandle).to_string(), - ret.type_code.to_string() - )) - } - } -} - -impl TryFrom for Function { - type Error = Error; - fn try_from(ret: TVMRetValue) -> Result { - if let Ok(handle) = ret.box_value.downcast::() { - Ok(Function::new(*handle, false, false)) - } else { - bail!(ErrorKind::TryFromTVMRetValueError( - stringify!(TVMTypeCode::kFuncHandle).to_string(), - ret.type_code.to_string() - )) - } - } -} +impl_boxed_ret_value!(TVMContext, ffi::TVMTypeCode_kTVMContext); +impl_boxed_ret_value!(TVMByteArray, ffi::TVMTypeCode_kBytes); -impl TryFrom for NDArray { +impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for TVMByteArray { type Error = Error; - fn try_from(ret: TVMRetValue) -> Result { - if let Ok(handle) = ret.box_value.downcast::() { - Ok(NDArray::new(*handle, false)) - } else { - bail!(ErrorKind::TryFromTVMRetValueError( - stringify!(TVMTypeCode::kArrayHandle).to_string(), - ret.type_code.to_string() - )) - } + fn try_from(arg: &TVMArgValue<'v>) -> Result { + ensure_type!(arg, ffi::TVMTypeCode_kBytes); + Ok(TVMByteArray::new(unsafe { + *(arg.value.v_handle as *mut ffi::TVMByteArray) + })) } } #[cfg(test)] mod tests { use super::*; - use std::convert::TryInto; + use std::{convert::TryInto, str::FromStr}; + use tvm_common::ffi::TVMType; #[test] fn bytearray() { @@ -227,7 +140,7 @@ mod tests { #[test] fn ty() { - let t = TVMType::from("int32"); + let t = TVMType::from_str("int32").unwrap(); let tvm: TVMType = TVMRetValue::from(t).try_into().unwrap(); assert_eq!(tvm, t); } diff --git a/rust/frontend/tests/basics/src/main.rs b/rust/frontend/tests/basics/src/main.rs index 69b948e9117d..55c537bfd362 100644 --- a/rust/frontend/tests/basics/src/main.rs +++ b/rust/frontend/tests/basics/src/main.rs @@ -1,6 +1,8 @@ extern crate ndarray as rust_ndarray; extern crate tvm_frontend as tvm; +use std::str::FromStr; + use tvm::*; fn main() { @@ -12,7 +14,7 @@ fn main() { } else { (TVMContext::gpu(0), "gpu") }; - let dtype = TVMType::from("float32"); + let dtype = TVMType::from_str("float32").unwrap(); let mut arr = NDArray::empty(shape, ctx, dtype); arr.copy_from_buffer(data.as_mut_slice()); let mut ret = NDArray::empty(shape, ctx, dtype); @@ -26,8 +28,7 @@ fn main() { function::Builder::from(&mut fadd) .arg(&arr) .arg(&arr) - .set_output(&mut ret) - .unwrap() + .arg(&mut ret) .invoke() .unwrap(); diff --git a/rust/frontend/tests/callback/src/bin/array.rs b/rust/frontend/tests/callback/src/bin/array.rs index 81dcadc30851..e77ea435a057 100644 --- a/rust/frontend/tests/callback/src/bin/array.rs +++ b/rust/frontend/tests/callback/src/bin/array.rs @@ -1,4 +1,3 @@ -#![feature(extern_crate_item_prelude, try_from)] #![allow(unused_imports)] extern crate ndarray as rust_ndarray; @@ -6,17 +5,23 @@ extern crate ndarray as rust_ndarray; extern crate tvm_frontend as tvm; use rust_ndarray::ArrayD; -use std::convert::{TryFrom, TryInto}; +use std::{ + convert::{TryFrom, TryInto}, + str::FromStr, +}; -use tvm::*; +use tvm::{errors::Error, *}; fn main() { register_global_func! { - fn sum(args: &[TVMArgValue]) -> Result { + fn sum(args: &[TVMArgValue]) -> Result { let mut ret = 0f32; let shape = &mut [2]; for arg in args.iter() { - let e = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32")); + let e = NDArray::empty( + shape, TVMContext::cpu(0), + TVMType::from_str("float32").unwrap() + ); let arg: NDArray = arg.try_into()?; let arr = arg.copy_to_ndarray(e)?; let rnd: ArrayD = ArrayD::try_from(&arr)?; @@ -28,12 +33,16 @@ fn main() { let shape = &mut [2]; let mut data = vec![3f32, 4.0]; - let mut arr = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32")); + let mut arr = NDArray::empty( + shape, + TVMContext::cpu(0), + TVMType::from_str("float32").unwrap(), + ); arr.copy_from_buffer(data.as_mut_slice()); let mut registered = function::Builder::default(); let ret: f32 = registered - .get_function("sum", true) + .get_function("sum") .arg(&arr) .arg(&arr) .invoke() diff --git a/rust/frontend/tests/callback/src/bin/error.rs b/rust/frontend/tests/callback/src/bin/error.rs index f40f0f157815..24a1f07e9764 100644 --- a/rust/frontend/tests/callback/src/bin/error.rs +++ b/rust/frontend/tests/callback/src/bin/error.rs @@ -1,4 +1,4 @@ -#![feature(extern_crate_item_prelude, panic_info_message)] +#![feature(panic_info_message)] #![allow(unused_imports)] use std::panic; @@ -6,20 +6,20 @@ use std::panic; #[macro_use] extern crate tvm_frontend as tvm; -use tvm::*; +use tvm::{errors::Error, *}; fn main() { register_global_func! { - fn error(_args: &[TVMArgValue]) -> Result { - Err(ErrorKind::TypeMismatch( - format!("{}", "i64".to_string()), - format!("{}", "f64".to_string()), - ).into()) + fn error(_args: &[TVMArgValue]) -> Result { + Err(errors::TypeMismatchError{ + expected: "i64".to_string(), + actual: "f64".to_string(), + }.into()) } } let mut registered = function::Builder::default(); - registered.get_function("error", true); + registered.get_function("error"); assert!(registered.func.is_some()); registered.args(&[10, 20]); diff --git a/rust/frontend/tests/callback/src/bin/float.rs b/rust/frontend/tests/callback/src/bin/float.rs index 3070552843d7..a26487be8678 100644 --- a/rust/frontend/tests/callback/src/bin/float.rs +++ b/rust/frontend/tests/callback/src/bin/float.rs @@ -1,26 +1,25 @@ -#![feature(extern_crate_item_prelude, try_from)] #![allow(unused_imports)] #[macro_use] extern crate tvm_frontend as tvm; use std::convert::TryInto; -use tvm::*; +use tvm::{errors::Error, *}; fn main() { register_global_func! { - fn sum(args: &[TVMArgValue]) -> Result { + fn sum(args: &[TVMArgValue]) -> Result { let mut ret = 0.0; - for arg in args.iter() { + for arg in args.into_iter() { let val: f64 = arg.try_into()?; ret += val; } - Ok(TVMRetValue::from(&ret)) + Ok(TVMRetValue::from(ret)) } } let mut registered = function::Builder::default(); - registered.get_function("sum", true); + registered.get_function("sum"); assert!(registered.func.is_some()); let ret: f64 = registered .args(&[10.0f64, 20.0, 30.0]) diff --git a/rust/frontend/tests/callback/src/bin/int.rs b/rust/frontend/tests/callback/src/bin/int.rs index 30188222054a..591f95a660a1 100644 --- a/rust/frontend/tests/callback/src/bin/int.rs +++ b/rust/frontend/tests/callback/src/bin/int.rs @@ -1,25 +1,24 @@ -#![feature(extern_crate_item_prelude, try_from)] #![allow(unused_imports)] extern crate tvm_frontend as tvm; use std::convert::TryInto; -use tvm::*; +use tvm::{errors::Error, *}; fn main() { - fn sum(args: &[TVMArgValue]) -> Result { + fn sum(args: &[TVMArgValue]) -> Result { let mut ret = 0i64; for arg in args.iter() { let val: i64 = arg.try_into()?; ret += val; } - Ok(TVMRetValue::from(&ret)) + Ok(TVMRetValue::from(ret)) } tvm::function::register(sum, "mysum".to_owned(), false).unwrap(); let mut registered = function::Builder::default(); - registered.get_function("mysum", true); + registered.get_function("mysum"); assert!(registered.func.is_some()); let ret: i64 = registered .args(&[10, 20, 30]) diff --git a/rust/frontend/tests/callback/src/bin/string.rs b/rust/frontend/tests/callback/src/bin/string.rs index eafee31796bd..3b2ad65a2f45 100644 --- a/rust/frontend/tests/callback/src/bin/string.rs +++ b/rust/frontend/tests/callback/src/bin/string.rs @@ -1,31 +1,32 @@ -#![feature(extern_crate_item_prelude, try_from)] #![allow(unused_imports)] #[macro_use] extern crate tvm_frontend as tvm; use std::convert::TryInto; -use tvm::*; +use tvm::{errors::Error, *}; // FIXME fn main() { register_global_func! { - fn concate_str(args: &[TVMArgValue]) -> Result { + fn concate_str(args: &[TVMArgValue]) -> Result { let mut ret = "".to_string(); for arg in args.iter() { - let val: String = arg.try_into()?; - ret += val.as_str(); + let val: &str = arg.try_into()?; + ret += val; } Ok(TVMRetValue::from(ret)) } } + let a = std::ffi::CString::new("a").unwrap(); + let b = std::ffi::CString::new("b").unwrap(); + let c = std::ffi::CString::new("c").unwrap(); let mut registered = function::Builder::default(); - registered.get_function("concate_str", true); + registered.get_function("concate_str"); assert!(registered.func.is_some()); - let a = "a".to_string(); - let b = "b".to_string(); - let c = "c".to_string(); let ret: String = registered - .args(&[a, b, c]) + .arg(&a) + .arg(&b) + .arg(&c) .invoke() .unwrap() .try_into() diff --git a/rust/runtime/.gitignore b/rust/runtime/.gitignore deleted file mode 100644 index 230ab66104df..000000000000 --- a/rust/runtime/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -Cargo.lock -target/ -**/*.rs.bk diff --git a/rust/runtime/Cargo.toml b/rust/runtime/Cargo.toml index d48c0d98c051..ae73ae721224 100644 --- a/rust/runtime/Cargo.toml +++ b/rust/runtime/Cargo.toml @@ -15,15 +15,15 @@ sgx = ["nom/alloc"] [dependencies] bounded-spsc-queue = "0.4.0" -error-chain = { version = "0.12.0", default-features = false } +failure = "0.1.5" itertools = "0.7.8" lazy_static = "1.1.0" -ndarray = "0.11.2" +ndarray="0.12.1" nom = {version = "4.0.0", default-features = false } serde = "1.0.59" serde_derive = "1.0.79" serde_json = "1.0.17" -tvm-common = { version = "0.1.0", path = "../common/", features = ["runtime"] } +tvm-common = { version = "0.1.0", path = "../common/" } [target.'cfg(not(target_env = "sgx"))'.dependencies] num_cpus = "1.8.0" diff --git a/rust/runtime/src/allocator.rs b/rust/runtime/src/allocator.rs index 5f77037e25f3..0514dce2b0a7 100644 --- a/rust/runtime/src/allocator.rs +++ b/rust/runtime/src/allocator.rs @@ -1,9 +1,7 @@ #[cfg(target_env = "sgx")] -use alloc::alloc::{self, Layout}; +use alloc::alloc::{self, Layout, LayoutErr}; #[cfg(not(target_env = "sgx"))] -use std::alloc::{self, Layout}; - -use crate::errors::*; +use std::alloc::{self, Layout, LayoutErr}; const DEFAULT_ALIGN_BYTES: usize = 4; @@ -15,7 +13,7 @@ pub struct Allocation { impl Allocation { /// Allocates a chunk of memory of `size` bytes with optional alignment. - pub fn new(size: usize, align: Option) -> Result { + pub fn new(size: usize, align: Option) -> Result { let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES); let layout = Layout::from_size_align(size, alignment)?; let ptr = unsafe { alloc::alloc(layout.clone()) }; diff --git a/rust/runtime/src/array.rs b/rust/runtime/src/array.rs index 5c49515a0da3..3bb02f12c866 100644 --- a/rust/runtime/src/array.rs +++ b/rust/runtime/src/array.rs @@ -1,23 +1,17 @@ -use std::{ - any::TypeId, - convert::TryFrom, - mem, - ops::{Deref, DerefMut}, - os::raw::{c_int, c_void}, - ptr, slice, -}; +use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice}; +use failure::Error; use ndarray; - -use crate::{ - allocator::Allocation, - errors::*, - ffi::runtime::{ +use tvm_common::{ + array::{DataType, TVMContext}, + ffi::{ DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, - DLDataTypeCode_kDLUInt, DLDeviceType_kDLCPU, DLTensor as _DLTensor, + DLDataTypeCode_kDLUInt, DLTensor, }, }; +use crate::allocator::Allocation; + /// A `Storage` is a container which holds `Tensor` data. #[derive(PartialEq)] pub enum Storage<'a> { @@ -29,7 +23,7 @@ pub enum Storage<'a> { } impl<'a> Storage<'a> { - pub fn new(size: usize, align: Option) -> Result> { + pub fn new(size: usize, align: Option) -> Result, Error> { Ok(Storage::Owned(Allocation::new(size, align)?)) } @@ -237,6 +231,27 @@ impl<'a> Tensor<'a> { byte_offset: 0, } } + + pub(crate) fn as_dltensor(&self, flatten: bool) -> DLTensor { + assert!(!flatten || self.is_contiguous()); + DLTensor { + data: unsafe { self.data.as_mut_ptr().offset(self.byte_offset) } as *mut c_void, + ctx: DLContext::from(&self.ctx), + ndim: if flatten { 1 } else { self.shape.len() } as i32, + dtype: DLDataType::from(&self.dtype), + shape: if flatten { + &self.size as *const _ as *mut i64 + } else { + self.shape.as_ptr() + } as *mut i64, + strides: if flatten || self.is_contiguous() { + ptr::null_mut() + } else { + self.strides.as_ref().unwrap().as_ptr() + } as *mut i64, + byte_offset: 0, + } + } } /// Conversions to `ndarray::Array` from `Tensor`, if the types match. @@ -244,7 +259,7 @@ macro_rules! impl_ndarray_try_from_tensor { ($type:ty, $dtype:expr) => { impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> { type Error = Error; - fn try_from(tensor: &'a Tensor) -> Result> { + fn try_from(tensor: &'a Tensor) -> Result, Error> { ensure!( tensor.dtype == $dtype, "Cannot convert Tensor with dtype {:?} to ndarray", @@ -263,120 +278,9 @@ macro_rules! impl_ndarray_try_from_tensor { }; } -impl_ndarray_try_from_tensor!(i32, DTYPE_INT32); -impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32); -impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32); -impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64); - -pub struct DLTensor { - pub(crate) inner: _DLTensor, -} - -impl Deref for DLTensor { - type Target = _DLTensor; - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -impl DerefMut for DLTensor { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.inner - } -} - -impl DLTensor { - pub(crate) fn new(raw: _DLTensor) -> Self { - Self { inner: raw } - } - - pub(crate) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self { - assert!(!flatten || tensor.is_contiguous()); - Self { - inner: _DLTensor { - data: unsafe { tensor.data.as_mut_ptr().offset(tensor.byte_offset) } as *mut c_void, - ctx: DLContext::from(&tensor.ctx), - ndim: if flatten { 1 } else { tensor.shape.len() } as i32, - dtype: DLDataType::from(&tensor.dtype), - shape: if flatten { - &tensor.size as *const _ as *mut i64 - } else { - tensor.shape.as_ptr() - } as *mut i64, - strides: if flatten || tensor.is_contiguous() { - ptr::null_mut() - } else { - tensor.strides.as_ref().unwrap().as_ptr() - } as *mut i64, - byte_offset: 0, - }, - } - } -} - -impl<'a, 't> From<&'a Tensor<'t>> for DLTensor { - fn from(tensor: &'a Tensor<'t>) -> Self { - DLTensor::from_tensor(tensor, false /* flatten */) - } -} - -impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor { - fn from(tensor: &'a mut Tensor<'t>) -> Self { - DLTensor::from_tensor(tensor, false /* flatten */) - } -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub struct DataType { - pub(crate) code: usize, - pub(crate) bits: usize, - pub(crate) lanes: usize, -} - -impl DataType { - /// Returns the number of bytes occupied by an element of this `DataType`. - pub fn itemsize(&self) -> usize { - (self.bits * self.lanes) >> 3 - } - - /// Returns whether this `DataType` represents primitive type `T`. - pub fn is_type(&self) -> bool { - if self.lanes != 1 { - return false; - } - let typ = TypeId::of::(); - (typ == TypeId::of::() && self.code == 0 && self.bits == 32) - || (typ == TypeId::of::() && self.code == 0 && self.bits == 64) - || (typ == TypeId::of::() && self.code == 1 && self.bits == 32) - || (typ == TypeId::of::() && self.code == 1 && self.bits == 64) - || (typ == TypeId::of::() && self.code == 2 && self.bits == 32) - || (typ == TypeId::of::() && self.code == 2 && self.bits == 64) - } -} - -impl<'a> From<&'a DataType> for DLDataType { - fn from(dtype: &'a DataType) -> Self { - Self { - code: dtype.code as u8, - bits: dtype.bits as u8, - lanes: dtype.lanes as u16, - } - } -} - -impl From for DataType { - fn from(dtype: DLDataType) -> Self { - Self { - code: dtype.code as usize, - bits: dtype.bits as usize, - lanes: dtype.lanes as usize, - } - } -} - macro_rules! make_dtype_const { ($name: ident, $code: ident, $bits: expr, $lanes: expr) => { - const $name: DataType = DataType { + pub const $name: DataType = DataType { code: $code as usize, bits: $bits, lanes: $lanes, @@ -389,28 +293,20 @@ make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1); // make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1); make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1); make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1); +impl_ndarray_try_from_tensor!(i32, DTYPE_INT32); +impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32); +impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32); +impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64); -#[derive(Debug, Clone, Copy, PartialEq)] -pub struct TVMContext { - pub(crate) device_type: usize, - pub(crate) device_id: usize, -} - -impl<'a> From<&'a TVMContext> for DLContext { - fn from(ctx: &'a TVMContext) -> Self { - Self { - device_type: ctx.device_type as u32, - device_id: ctx.device_id as i32, - } +impl<'a, 't> From<&'a Tensor<'t>> for DLTensor { + fn from(tensor: &'a Tensor<'t>) -> Self { + Tensor::as_dltensor(tensor, false /* flatten */) } } -impl Default for TVMContext { - fn default() -> Self { - Self { - device_type: DLDeviceType_kDLCPU as usize, - device_id: 0, - } +impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor { + fn from(tensor: &'a mut Tensor<'t>) -> Self { + Tensor::as_dltensor(tensor, false /* flatten */) } } @@ -463,42 +359,6 @@ macro_rules! impl_tensor_from_ndarray { }; } -/// `From` conversions to `DLTensor` for `ndarray::Array`. -/// Takes a reference to the `ndarray` since `DLTensor` is not owned. -macro_rules! impl_dltensor_from_ndarray { - ($type:ty, $typecode:expr) => { - impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor { - fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self { - DLTensor { - inner: _DLTensor { - data: arr.as_mut_ptr() as *mut c_void, - ctx: DLContext { - device_type: DLDeviceType_kDLCPU, - device_id: 0, - }, - ndim: arr.ndim() as c_int, - dtype: DLDataType { - code: $typecode as u8, - bits: 8 * mem::size_of::<$type>() as u8, - lanes: 1, - }, - shape: arr.shape().as_ptr() as *const i64 as *mut i64, - strides: arr.strides().as_ptr() as *const isize as *mut i64, - byte_offset: 0, - }, - } - } - } - }; -} - -impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat); -impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat); -impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt); -impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt); -impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt); -impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt); - impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat); impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat); impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt); diff --git a/rust/runtime/src/errors.rs b/rust/runtime/src/errors.rs index cf7723034882..26a8961697c3 100644 --- a/rust/runtime/src/errors.rs +++ b/rust/runtime/src/errors.rs @@ -1,36 +1,19 @@ -#[cfg(target_env = "sgx")] -use alloc::alloc; -#[cfg(not(target_env = "sgx"))] -use std::alloc; -use std::num; - -use crate::common::errors as common_errors; -use ndarray; -use serde_json; - -error_chain! { - errors { - GraphFormatError(msg: String) { - description("unable to load graph") - display("could not load graph json: {}", msg) - } - - LoadGraphParamsError(msg: String) { - description("unable to load graph params") - display("could not load graph params: {}", msg) - } - } - foreign_links { - Alloc(alloc::AllocErr); - GraphDeserialize(serde_json::Error); - ParseInt(num::ParseIntError); - ShapeError(ndarray::ShapeError); - CommonError(common_errors::Error); - } +#[derive(Debug, Fail)] +pub enum GraphFormatError { + #[fail(display = "Could not parse graph json")] + Parse(#[fail(cause)] failure::Error), + #[fail(display = "Could not parse graph params")] + Params, + #[fail(display = "{} is missing attr: {}", 0, 1)] + MissingAttr(String, String), + #[fail(display = "Missing field: {}", 0)] + MissingField(&'static str), + #[fail(display = "Invalid DLType: {}", 0)] + InvalidDLType(String), } -impl From for Error { - fn from(_err: alloc::LayoutErr) -> Error { - Error::from_kind(ErrorKind::Msg("Layout error".to_string())) - } +#[derive(Debug, Fail)] +#[fail(display = "SGX error: 0x{:x}", code)] +pub struct SgxError { + pub code: u32, } diff --git a/rust/runtime/src/graph.rs b/rust/runtime/src/graph.rs index 0d5e281f3f77..6e00d9c7a14c 100644 --- a/rust/runtime/src/graph.rs +++ b/rust/runtime/src/graph.rs @@ -1,16 +1,17 @@ use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str}; +use failure::Error; use nom::{alpha1, digit1, le_i32, le_i64, le_u16, le_u32, le_u64, le_u8, types::CompleteStr}; use serde; use serde_json; - -use super::{DLTensor, DataType, Module, Storage, TVMContext, Tensor}; -use crate::{ - common::value::TVMArgValue, - errors::{Error, ErrorKind, Result}, - ffi::runtime::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt}, +use tvm_common::{ + array::{DataType, TVMContext}, + ffi::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, DLTensor}, + TVMArgValue, }; +use crate::{errors::GraphFormatError, Module, Storage, Tensor}; + // @see `kTVMNDArrayMagic` in `ndarray.h` const _NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F; // @see `kTVMNDArrayListMagic` in `graph_runtime.h` @@ -41,28 +42,26 @@ pub struct Entry { } impl Graph { - fn entry_index(&self, entry: &Entry) -> Result { + fn entry_index(&self, entry: &Entry) -> Result { self.node_row_ptr .as_ref() .map(|nrp| nrp[entry.id] + entry.index) - .ok_or("Missing node_row_ptr.".into()) + .ok_or_else(|| GraphFormatError::MissingField("node_row_ptr")) } /// Attempt to deserialize a JSON attribute to a type `T`. - fn get_attr(&self, attr: &str) -> Result { + fn get_attr(&self, attr: &str) -> Result { Ok(serde_json::from_value::( self.attrs .as_ref() - .ok_or(ErrorKind::GraphFormatError( - "Missing graph attrs".to_string(), - ))? + .ok_or(GraphFormatError::MissingField("attrs"))? .get(attr) - .ok_or(ErrorKind::GraphFormatError(format!( - "Missing {} attr", - attr - )))? + .ok_or_else(|| { + GraphFormatError::MissingAttr("graph".to_string(), attr.to_string()) + })? .to_owned(), - )?) + ) + .map_err(|err| GraphFormatError::Parse(err.into()))?) } } @@ -81,39 +80,31 @@ struct NodeAttrs { flatten_data: bool, } +macro_rules! get_node_attr { + ($node:expr, $attrs:ident, $attr:literal) => { + $attrs + .get($attr) + .ok_or_else(|| GraphFormatError::MissingAttr($node.to_owned(), $attr.to_owned())) + }; +} + impl Node { - fn parse_attrs(&self) -> Result { + fn parse_attrs(&self) -> Result { let attrs = self .attrs .as_ref() - .ok_or(format!("Missing node.attrs for `{}`", self.name))?; - let func_name = attrs - .get("func_name") - .ok_or(format!("Node `{}` is missing attrs.func_name", self.name))? - .to_string(); - let num_outputs = attrs - .get("num_outputs") - .ok_or(format!("Node `{}` is missing attrs.num_outputs", self.name))? - .parse::()?; - let flatten_data = attrs - .get("flatten_data") - .ok_or(format!( - "Node `{}` is missing attrs.flatten_data", - self.name - ))? - .parse::()? - == 1; + .ok_or_else(|| GraphFormatError::MissingAttr(self.name.clone(), "attrs".to_owned()))?; Ok(NodeAttrs { - func_name, - num_outputs, - flatten_data, + func_name: get_node_attr!(self.name, attrs, "func_name")?.to_owned(), + num_outputs: get_node_attr!(self.name, attrs, "num_outputs")?.parse::()?, + flatten_data: get_node_attr!(self.name, attrs, "flatten_data")?.parse::()? == 1, }) } } impl<'a> TryFrom<&'a String> for Graph { type Error = Error; - fn try_from(graph_json: &String) -> Result { + fn try_from(graph_json: &String) -> Result { let graph = serde_json::from_str(graph_json)?; Ok(graph) } @@ -121,7 +112,7 @@ impl<'a> TryFrom<&'a String> for Graph { impl<'a> TryFrom<&'a str> for Graph { type Error = Error; - fn try_from(graph_json: &'a str) -> Result { + fn try_from(graph_json: &'a str) -> Result { let graph = serde_json::from_str(graph_json)?; Ok(graph) } @@ -161,7 +152,7 @@ pub struct GraphExecutor<'m, 't> { unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {} impl<'m, 't> GraphExecutor<'m, 't> { - pub fn new(graph: Graph, lib: &'m M) -> Result { + pub fn new(graph: Graph, lib: &'m M) -> Result { let tensors = Self::setup_storages(&graph)?; Ok(GraphExecutor { op_execs: Self::setup_op_execs(&graph, lib, &tensors)?, @@ -178,7 +169,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { } /// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output. - fn setup_storages<'a>(graph: &'a Graph) -> Result>> { + fn setup_storages<'a>(graph: &'a Graph) -> Result>, Error> { let storage_ids = graph.get_attr::<(String, Vec)>("storage_id")?.1; let shapes = graph.get_attr::<(String, Vec>)>("shape")?.1; let dtypes = graph @@ -189,18 +180,15 @@ impl<'m, 't> GraphExecutor<'m, 't> { if let Ok((_, dtype)) = tvm_str_to_type(CompleteStr(dltype)) { Ok(dtype) } else { - Err(ErrorKind::GraphFormatError( - format!("Invalid dltype: {}", dltype).to_string(), - ) - .into()) + Err(GraphFormatError::InvalidDLType(dltype.to_string())) } }) - .collect::>>()?; + .collect::, GraphFormatError>>()?; - let align = dtypes.iter().map(|dtype| dtype.bits as usize).max(); + let align = dtypes.iter().map(|dtype| dtype.bits() as usize).max(); let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1]; for (i, &storage_id) in storage_ids.iter().enumerate() { - let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3; + let dtype_size = dtypes[i].bits() * dtypes[i].lanes() >> 3; let nbytes = dtype_size * shapes[i].iter().product::() as usize; storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]); } @@ -208,7 +196,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { let mut storages: Vec = storage_num_bytes .into_iter() .map(|nbytes| Storage::new(nbytes, align)) - .collect::>>()?; + .collect::, Error>>()?; let tensors = izip!(storage_ids, shapes, dtypes) .map(|(storage_id, shape, dtype)| { @@ -233,7 +221,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { graph: &Graph, lib: &'m M, tensors: &Vec>, - ) -> Result>> { + ) -> Result>, Error> { ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr."); let node_row_ptr = graph.node_row_ptr.as_ref().unwrap(); @@ -251,9 +239,10 @@ impl<'m, 't> GraphExecutor<'m, 't> { continue; } - let func = lib - .get_function(&attrs.func_name) - .ok_or(format!("Missing function {}", attrs.func_name))?; + let func = lib.get_function(&attrs.func_name).ok_or(format_err!( + "Library is missing function {}", + attrs.func_name + ))?; let arg_indices = node .inputs .iter() @@ -264,19 +253,19 @@ impl<'m, 't> GraphExecutor<'m, 't> { .map(|idx| { let tensor = &tensors[idx?]; Ok(if attrs.flatten_data { - DLTensor::from_tensor(tensor, true /* flatten */) + Tensor::as_dltensor(tensor, true /* flatten */) } else { DLTensor::from(tensor) }) }) - .collect::>>() + .collect::, Error>>() .unwrap(); let op: Box = box move || { let args = dl_tensors .iter() .map(|t| t.into()) .collect::>(); - func(args.as_slice()); + func(args.as_slice()).unwrap(); }; op_execs.push(op); } @@ -344,7 +333,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { } } -/// Converts a string to TVM DLDataTypeCode. @see `String2TVMType` in packed_func.h +// Converts a string to TVM DLDataTypeCode. @see `String2TVMType` in packed_func.h named!( tvm_str_to_type, do_parse!( @@ -367,7 +356,7 @@ named!( ) ); -/// Converts a bytes to String. +// Converts a bytes to String. named!( name, map_res!(length_bytes!(le_u64), |b: &[u8]| String::from_utf8( @@ -375,7 +364,7 @@ named!( )) ); -/// Parses a TVMContext +// Parses a TVMContext named!( tvm_ctx<&[u8], TVMContext>, do_parse!( @@ -385,7 +374,7 @@ named!( ) ); -/// Parses a DataType +// Parses a DataType named!( data_type<&[u8], DataType>, do_parse!( @@ -396,7 +385,7 @@ named!( ) ); -/// Parses a Tensor from a TVM array file. +// Parses a Tensor from a TVM array file. named!( tensor, do_parse!( @@ -420,7 +409,7 @@ named!( ) ); -/// Parses a graph params dict from a params binary file. +// Parses a graph params dict from a params binary file. named!( parse_param_dict>, do_parse!( @@ -433,17 +422,15 @@ named!( ); /// Loads a param dict saved using `nnvm.compiler.save_param_dict`. -pub fn load_param_dict(bytes: &[u8]) -> Result> { +pub fn load_param_dict(bytes: &[u8]) -> Result, GraphFormatError> { if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) { - if remaining_bytes.len() > 0 { - bail!(ErrorKind::LoadGraphParamsError("extra input".to_string())) - } else { + if remaining_bytes.len() == 0 { Ok(param_dict) + } else { + Err(GraphFormatError::Params) } } else { - bail!(ErrorKind::LoadGraphParamsError( - "invalid parameters file".to_string() - )) + Err(GraphFormatError::Params) } } diff --git a/rust/runtime/src/lib.rs b/rust/runtime/src/lib.rs index da030bc4be65..848db27ecdcc 100644 --- a/rust/runtime/src/lib.rs +++ b/rust/runtime/src/lib.rs @@ -14,7 +14,6 @@ allocator_api, box_syntax, fn_traits, - try_from, unboxed_closures, vec_remove_item )] @@ -25,7 +24,7 @@ extern crate bounded_spsc_queue; #[cfg(target_env = "sgx")] extern crate core; #[macro_use] -extern crate error_chain; +extern crate failure; #[macro_use] extern crate itertools; #[macro_use] @@ -39,36 +38,45 @@ extern crate serde; #[macro_use] extern crate serde_derive; extern crate serde_json; -extern crate tvm_common as common; +extern crate tvm_common; mod allocator; mod array; pub mod errors; -mod module; -#[macro_use] -mod packed_func; mod graph; +mod module; #[cfg(target_env = "sgx")] #[macro_use] pub mod sgx; mod threading; mod workspace; -pub use crate::common::{errors::*, ffi, TVMArgValue, TVMRetValue}; - -pub use self::{ - array::*, errors::*, graph::*, module::*, packed_func::*, threading::*, workspace::*, +pub use tvm_common::{ + call_packed, + errors::*, + ffi::{self, DLTensor}, + packed_func::{self, *}, + TVMArgValue, TVMRetValue, }; -#[cfg(target_env = "sgx")] -use self::sgx::ocall_packed_func; +pub use self::{array::*, errors::*, graph::*, module::*, threading::*, workspace::*}; + +lazy_static! { + static ref LAST_ERROR: std::sync::RwLock> = + std::sync::RwLock::new(None); +} #[no_mangle] pub extern "C" fn TVMAPISetLastError(cmsg: *const i8) { - #[cfg(not(target_env = "sgx"))] - unsafe { - panic!(std::ffi::CStr::from_ptr(cmsg).to_str().unwrap()); - } + *LAST_ERROR.write().unwrap() = Some(unsafe { std::ffi::CStr::from_ptr(cmsg) }); #[cfg(target_env = "sgx")] ocall_packed!("__sgx_set_last_error__", cmsg); } + +#[no_mangle] +pub extern "C" fn TVMGetLastError() -> *const std::os::raw::c_char { + match *LAST_ERROR.read().unwrap() { + Some(err) => err.as_ptr(), + None => std::ptr::null(), + } +} diff --git a/rust/runtime/src/module.rs b/rust/runtime/src/module.rs index 8e6f7d665dd4..636c4e8ff5cf 100644 --- a/rust/runtime/src/module.rs +++ b/rust/runtime/src/module.rs @@ -2,29 +2,29 @@ use std::{ collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex, }; -use crate::{ - ffi::runtime::BackendPackedCFunc, - packed_func::{wrap_backend_packed_func, PackedFunc}, +use tvm_common::{ + ffi::BackendPackedCFunc, + packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue}, }; pub trait Module { - fn get_function>(&self, name: S) -> Option; + fn get_function>(&self, name: S) -> Option<&(dyn PackedFunc)>; } pub struct SystemLibModule; lazy_static! { - static ref SYSTEM_LIB_FUNCTIONS: Mutex> = + static ref SYSTEM_LIB_FUNCTIONS: Mutex> = Mutex::new(HashMap::new()); } impl Module for SystemLibModule { - fn get_function>(&self, name: S) -> Option { + fn get_function>(&self, name: S) -> Option<&(dyn PackedFunc)> { SYSTEM_LIB_FUNCTIONS .lock() .unwrap() .get(name.as_ref()) - .map(|func| wrap_backend_packed_func(func.to_owned())) + .map(|f| *f) } } @@ -34,15 +34,42 @@ impl Default for SystemLibModule { } } +// @see `WrapPackedFunc` in `llvm_module.cc`. +pub(super) fn wrap_backend_packed_func( + func_name: String, + func: BackendPackedCFunc, +) -> Box { + box move |args: &[TVMArgValue]| { + let exit_code = func( + args.iter() + .map(|ref arg| arg.value) + .collect::>() + .as_ptr(), + args.iter() + .map(|ref arg| arg.type_code as i32) + .collect::>() + .as_ptr() as *const i32, + args.len() as i32, + ); + if exit_code == 0 { + Ok(TVMRetValue::default()) + } else { + Err(tvm_common::errors::FuncCallError::get_with_context( + func_name.clone(), + )) + } + } +} + #[no_mangle] pub extern "C" fn TVMBackendRegisterSystemLibSymbol( cname: *const c_char, func: BackendPackedCFunc, ) -> i32 { let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() }; - SYSTEM_LIB_FUNCTIONS - .lock() - .unwrap() - .insert(name.to_string(), func); + SYSTEM_LIB_FUNCTIONS.lock().unwrap().insert( + name.to_string(), + &*Box::leak(wrap_backend_packed_func(name.to_string(), func)), + ); return 0; } diff --git a/rust/runtime/src/packed_func.rs b/rust/runtime/src/packed_func.rs deleted file mode 100644 index 2fe0086e9a0d..000000000000 --- a/rust/runtime/src/packed_func.rs +++ /dev/null @@ -1,118 +0,0 @@ -use std::{convert::TryFrom, marker::PhantomData, os::raw::c_void}; - -use super::Tensor; -use crate::ffi::runtime::{ - BackendPackedCFunc, DLTensor as _DLTensor, TVMTypeCode_kArrayHandle, - TVMTypeCode_kNDArrayContainer, TVMValue as _TVMValue, -}; - -use super::DLTensor; -use crate::{ - common::{TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue}, - errors::*, -}; - -pub type PackedFunc = Box TVMRetValue + Send + Sync>; - -/// Calls a packed function and returns a `TVMRetValue`. -/// -/// # Example -/// -/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)` -#[macro_export] -macro_rules! call_packed { - ($fn:expr, $($args:expr),+) => { - $fn(&[$($args.into(),)+]) - }; - ($fn:expr) => { - $fn(&Vec::new()) - }; -} - -impl<'a> From<&'a DLTensor> for TVMArgValue<'a> { - fn from(arr: &'a DLTensor) -> Self { - let raw = _TVMValue { - v_handle: arr as *const _ as *mut DLTensor as *mut c_void, - }; - TVMArgValue { - value: TVMValue::new(raw), - type_code: TVMTypeCode::kArrayHandle, - lifetime: PhantomData, - } - } -} - -impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> { - fn from(arr: &'a mut DLTensor) -> Self { - let raw = _TVMValue { - v_handle: arr as *mut _ as *mut c_void, - }; - TVMArgValue { - value: TVMValue::new(raw), - type_code: TVMTypeCode::kArrayHandle, - lifetime: PhantomData, - } - } -} - -impl<'a> TryFrom> for Tensor<'a> { - type Error = Error; - fn try_from(val: TVMArgValue<'a>) -> Result { - ensure!( - val.type_code == TVMTypeCode::kArrayHandle - || val.type_code == TVMTypeCode::kNDArrayContainer, - "Could not downcast arg. Expected `{}` or `{}`, but got `{}`", - TVMTypeCode::kArrayHandle, - TVMTypeCode::kNDArrayContainer, - val.type_code, - ); - - let dlt = unsafe { *(val.value.v_handle as *mut _DLTensor as *const _DLTensor) }; - Ok(DLTensor::new(dlt).into()) - } -} - -impl<'a, 't> From<&'t Tensor<'a>> for TVMRetValue { - fn from(val: &'t Tensor<'a>) -> Self { - TVMRetValue { - prim_value: 0, - box_value: box DLTensor::from(val), - type_code: TVMTypeCode::kNDArrayContainer, - } - } -} - -impl<'a> TryFrom for Tensor<'a> { - type Error = Error; - fn try_from(ret: TVMRetValue) -> Result { - ensure!( - ret.type_code == TVMTypeCode::kArrayHandle - || ret.type_code == TVMTypeCode::kNDArrayContainer, - "Could not downcast arg. Expected `{}` or `{}`, but got `{}`", - TVMTypeCode_kArrayHandle, - TVMTypeCode_kNDArrayContainer, - ret.type_code, - ); - - let dlt = unsafe { *(ret.prim_value as *mut _DLTensor as *const _DLTensor) }; - Ok(DLTensor::new(dlt).into()) - } -} - -// @see `WrapPackedFunc` in `llvm_module.cc`. -pub(crate) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc { - box move |args: &[TVMArgValue]| { - func( - args.iter() - .map(|ref arg| arg.value.inner) - .collect::>() - .as_ptr(), - args.iter() - .map(|ref arg| arg.type_code as i32) - .collect::>() - .as_ptr() as *const i32, - args.len() as i32, - ); - TVMRetValue::default() - } -} diff --git a/rust/runtime/src/sgx.rs b/rust/runtime/src/sgx.rs index 1edf3ef497e7..42d3aa4aaa7e 100644 --- a/rust/runtime/src/sgx.rs +++ b/rust/runtime/src/sgx.rs @@ -3,18 +3,17 @@ use std::{ os::raw::{c_char, c_int}, }; -use errors::Result; -use ffi::runtime::TVMValue; -use runtime::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue}; - -pub use runtime::threading::tvm_run_worker as run_worker; +pub use crate::threading::tvm_run_worker as run_worker; +use crate::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue}; +use errors::SgxError; +use ffi::TVMValue; #[macro_export] macro_rules! tvm_ocall { ($func: expr) => { match $func { 0 => Ok(()), - err => Err(format!("SGX error: {}", err)), + code => Err(SgxError { code }), } }; } @@ -33,7 +32,10 @@ extern "C" { ) -> SgxStatus; } -pub fn ocall_packed_func>(fn_name: S, args: &[TVMArgValue]) -> Result { +pub fn ocall_packed_func>( + fn_name: S, + args: &[TVMArgValue], +) -> Result { let mut ret_val = TVMValue { v_int64: 0 }; let ret_type_code = 0i64; unsafe { @@ -58,11 +60,11 @@ pub fn ocall_packed_func>(fn_name: S, args: &[TVMArgValue]) -> Res #[macro_export] macro_rules! ocall_packed { ($fn_name:expr, $($args:expr),+) => { - ocall_packed_func($fn_name, &[$($args.into(),)+]) + $crate::sgx::ocall_packed_func($fn_name, &[$($args.into(),)+]) .expect(concat!("Error calling `", $fn_name, "`")) }; ($fn_name:expr) => { - ocall_packed_func($fn_name, &Vec::new()) + $crate::sgx::ocall_packed_func($fn_name, &Vec::new()) .expect(concat!("Error calling `", $fn_name, "`")) } } diff --git a/rust/runtime/src/threading.rs b/rust/runtime/src/threading.rs index 38f4b7d23f0f..408c0b491bb0 100644 --- a/rust/runtime/src/threading.rs +++ b/rust/runtime/src/threading.rs @@ -1,7 +1,7 @@ use std::{ os::raw::{c_int, c_void}, sync::{ - atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT}, + atomic::{AtomicUsize, Ordering}, Arc, Barrier, }, }; @@ -18,11 +18,10 @@ use std::{ use std::{collections::VecDeque, ptr, sync::Mutex}; use bounded_spsc_queue::{self, Producer}; - -use crate::{errors::*, ffi::runtime::TVMParallelGroupEnv}; +use tvm_common::ffi::TVMParallelGroupEnv; #[cfg(target_env = "sgx")] -use super::{sgx::ocall_packed_func, TVMArgValue, TVMRetValue}; +use super::{TVMArgValue, TVMRetValue}; type FTVMParallelLambda = extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32; @@ -62,12 +61,11 @@ impl Job { } /// Waits for all tasks in this `Job` to be completed. - fn wait(&self) -> Result<()> { + fn wait(&self) { while self.pending.load(Ordering::Acquire) > 0 { #[cfg(not(target_env = "sgx"))] thread::yield_now(); } - Ok(()) } } @@ -161,7 +159,7 @@ impl ThreadPool { } tasks.pop().unwrap()(); - job.wait().unwrap(); + job.wait(); } fn run_worker(queue: Consumer) { @@ -251,7 +249,7 @@ pub extern "C" fn TVMBackendParallelLaunch( cb: cb, cdata: cdata, req_num_tasks: num_task, - pending: Arc::new(ATOMIC_USIZE_INIT), + pending: Arc::new(AtomicUsize::new(0)), }); }); } @@ -273,7 +271,7 @@ pub(crate) fn sgx_join_threads() { cb: poison_pill, cdata: ptr::null(), req_num_tasks: 0, - pending: Arc::new(ATOMIC_USIZE_INIT), + pending: Arc::new(AtomicUsize::new(0)), }); }); ocall_packed!("__sgx_thread_group_join__", 0); @@ -322,8 +320,8 @@ mod tests { #[test] fn test_parallel_launch() { TVMBackendParallelLaunch(flambda, ptr::null(), 6); - let counter = ATOMIC_USIZE_INIT; - let task_ids_sum = ATOMIC_USIZE_INIT; + let counter = AtomicUsize::new(0); + let task_ids_sum = AtomicUsize::new(0); let cdata = (counter, task_ids_sum); let num_tasks = 3; TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks); diff --git a/rust/runtime/src/workspace.rs b/rust/runtime/src/workspace.rs index a12a27e4c47c..1e29ec179bca 100644 --- a/rust/runtime/src/workspace.rs +++ b/rust/runtime/src/workspace.rs @@ -4,8 +4,9 @@ use std::{ ptr, }; -use super::allocator::Allocation; -use crate::errors::*; +use failure::Error; + +use crate::allocator::Allocation; const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h` @@ -24,13 +25,13 @@ impl WorkspacePool { } } - fn alloc_new(&mut self, size: usize) -> Result<*mut u8> { + fn alloc_new(&mut self, size: usize) -> Result<*mut u8, Error> { self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?); self.in_use.push(self.workspaces.len() - 1); Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr()) } - fn alloc(&mut self, size: usize) -> Result<*mut u8> { + fn alloc(&mut self, size: usize) -> Result<*mut u8, Error> { if self.free.len() == 0 { return self.alloc_new(size); } @@ -60,7 +61,7 @@ impl WorkspacePool { } } - fn free(&mut self, ptr: *mut u8) -> Result<()> { + fn free(&mut self, ptr: *mut u8) -> Result<(), Error> { let mut ws_idx = None; for i in 0..self.in_use.len() { let idx = self.in_use[i]; @@ -72,7 +73,7 @@ impl WorkspacePool { } Ok(self .free - .push(ws_idx.ok_or("Tried to free nonexistent workspace.")?)) + .push(ws_idx.ok_or(format_err!("Tried to free nonexistent workspace."))?)) } } diff --git a/rust/runtime/tests/test_nnvm/Cargo.toml b/rust/runtime/tests/test_nnvm/Cargo.toml index 14d0b3961ad3..259af2341d95 100644 --- a/rust/runtime/tests/test_nnvm/Cargo.toml +++ b/rust/runtime/tests/test_nnvm/Cargo.toml @@ -5,7 +5,7 @@ license = "Apache-2.0" authors = ["TVM Contributors"] [dependencies] -ndarray = "0.11.2" +ndarray="0.12.1" serde = "1.0.59" serde_json = "1.0.17" tvm-runtime = { path = "../../" } diff --git a/rust/runtime/tests/test_tvm_basic/Cargo.toml b/rust/runtime/tests/test_tvm_basic/Cargo.toml index 2a753b430c47..561215daf0a7 100644 --- a/rust/runtime/tests/test_tvm_basic/Cargo.toml +++ b/rust/runtime/tests/test_tvm_basic/Cargo.toml @@ -5,7 +5,7 @@ license = "Apache-2.0" authors = ["TVM Contributors"] [dependencies] -ndarray = "0.11.2" +ndarray="0.12.1" tvm-runtime = { path = "../../" } [build-dependencies] diff --git a/rust/runtime/tests/test_tvm_basic/src/main.rs b/rust/runtime/tests/test_tvm_basic/src/main.rs index f14fbec8c439..621315ddd34f 100644 --- a/rust/runtime/tests/test_tvm_basic/src/main.rs +++ b/rust/runtime/tests/test_tvm_basic/src/main.rs @@ -17,6 +17,6 @@ fn main() { let mut a_dl: DLTensor = (&mut a).into(); let mut b_dl: DLTensor = (&mut b).into(); let mut c_dl: DLTensor = (&mut c).into(); - call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl); + call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap(); assert!(c.all_close(&e, 1e-8f32)); } diff --git a/tests/scripts/task_rust.sh b/tests/scripts/task_rust.sh index be0181b4d95b..6e174208f2c7 100755 --- a/tests/scripts/task_rust.sh +++ b/tests/scripts/task_rust.sh @@ -14,11 +14,11 @@ cargo fmt -- --check # test common cd $RUST_DIR/common -cargo build --features runtime -cargo test --features runtime --tests +cargo build +cargo test --tests -cargo build --features frontend -cargo test --features frontend --tests +cargo build --features bindings +cargo test --features bindings --tests # test runtime cd $RUST_DIR/runtime