From 1f38ac71b8e12ef7f738b8ce15b7568a9b1cd4c7 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Sun, 3 Feb 2019 01:20:56 +0000 Subject: [PATCH 01/17] Begin unify Rust lib --- rust/common/Cargo.toml | 7 +- rust/common/{tvm-sys => }/build.rs | 17 +- rust/common/src/c_runtime_api.rs | 1072 ++++++++---------------- rust/common/src/errors.rs | 41 +- rust/common/src/lib.rs | 23 +- rust/common/src/packed_func.rs | 329 ++++++++ rust/common/src/ty.rs | 144 ---- rust/common/src/value.rs | 589 ++----------- rust/common/tvm-sys/Cargo.toml | 9 - rust/common/tvm-sys/src/lib.rs | 9 - rust/frontend/Cargo.toml | 2 +- rust/frontend/src/bytearray.rs | 8 +- rust/frontend/src/context.rs | 74 +- rust/frontend/src/errors.rs | 5 +- rust/frontend/src/function.rs | 214 ++--- rust/frontend/src/lib.rs | 48 +- rust/frontend/src/module.rs | 44 +- rust/frontend/src/ndarray.rs | 36 +- rust/frontend/src/ty.rs | 150 ---- rust/frontend/src/value.rs | 234 ++---- rust/frontend/tests/basics/src/main.rs | 10 +- rust/runtime/Cargo.toml | 2 +- rust/runtime/src/array.rs | 13 +- rust/runtime/src/graph.rs | 8 +- rust/runtime/src/packed_func.rs | 16 +- 25 files changed, 1053 insertions(+), 2051 deletions(-) rename rust/common/{tvm-sys => }/build.rs (56%) create mode 100644 rust/common/src/packed_func.rs delete mode 100644 rust/common/tvm-sys/Cargo.toml delete mode 100644 rust/common/tvm-sys/src/lib.rs delete mode 100644 rust/frontend/src/ty.rs diff --git a/rust/common/Cargo.toml b/rust/common/Cargo.toml index bcba5ad62fc9..dc1592b62952 100644 --- a/rust/common/Cargo.toml +++ b/rust/common/Cargo.toml @@ -5,9 +5,10 @@ authors = ["TVM Contributors"] license = "Apache-2.0" [features] -runtime = [] -frontend = ["tvm-sys"] +bindings = [] [dependencies] error-chain = { version = "0.12.0", default-features = false } -tvm-sys = { version = "0.1.0", path = "tvm-sys", optional = true } + +[build-dependencies] +bindgen = "0.37.4" diff --git a/rust/common/tvm-sys/build.rs b/rust/common/build.rs similarity index 56% rename from rust/common/tvm-sys/build.rs rename to rust/common/build.rs index f842043a1d16..1a45854bc4e3 100644 --- a/rust/common/tvm-sys/build.rs +++ b/rust/common/build.rs @@ -3,10 +3,13 @@ extern crate bindgen; use std::path::PathBuf; fn main() { - println!("cargo:rerun-if-env-changed=TVM_HOME"); - println!("cargo:rustc-link-lib=dylib=tvm_runtime"); - println!("cargo:rustc-link-search={}/build", env!("TVM_HOME")); - let bindings = bindgen::Builder::default() + if cfg!(feature = "bindings") { + println!("cargo:rerun-if-env-changed=TVM_HOME"); + println!("cargo:rustc-link-lib=dylib=tvm_runtime"); + println!("cargo:rustc-link-search={}/build", env!("TVM_HOME")); + } + + bindgen::Builder::default() .header(format!( "{}/include/tvm/runtime/c_runtime_api.h", env!("TVM_HOME") @@ -17,9 +20,7 @@ fn main() { .derive_partialeq(true) .derive_eq(true) .generate() - .expect("unable to generate bindings"); - - bindings - .write_to_file(PathBuf::from("src/bindgen.rs")) + .expect("unable to generate bindings") + .write_to_file(PathBuf::from("src/c_runtime_api.rs")) .expect("can not write the bindings!"); } diff --git a/rust/common/src/c_runtime_api.rs b/rust/common/src/c_runtime_api.rs index 6facf9ca274f..28e2518d7557 100644 --- a/rust/common/src/c_runtime_api.rs +++ b/rust/common/src/c_runtime_api.rs @@ -1,399 +1,217 @@ -/* automatically generated by rust-bindgen for TVM revision 6292c78 */ +/* automatically generated by rust-bindgen */ -pub const TVM_VERSION: &'static [u8; 8usize] = b"0.5.dev\0"; -pub const DLPACK_VERSION: u32 = 8; -pub const _STDINT_H: u32 = 1; -pub const _FEATURES_H: u32 = 1; -pub const _DEFAULT_SOURCE: u32 = 1; -pub const __USE_ISOC11: u32 = 1; -pub const __USE_ISOC99: u32 = 1; -pub const __USE_ISOC95: u32 = 1; -pub const __USE_POSIX_IMPLICITLY: u32 = 1; -pub const _POSIX_SOURCE: u32 = 1; -pub const _POSIX_C_SOURCE: u32 = 200809; -pub const __USE_POSIX: u32 = 1; -pub const __USE_POSIX2: u32 = 1; -pub const __USE_POSIX199309: u32 = 1; -pub const __USE_POSIX199506: u32 = 1; -pub const __USE_XOPEN2K: u32 = 1; -pub const __USE_XOPEN2K8: u32 = 1; -pub const _ATFILE_SOURCE: u32 = 1; -pub const __USE_MISC: u32 = 1; -pub const __USE_ATFILE: u32 = 1; -pub const __USE_FORTIFY_LEVEL: u32 = 0; -pub const _STDC_PREDEF_H: u32 = 1; -pub const __STDC_IEC_559__: u32 = 1; -pub const __STDC_IEC_559_COMPLEX__: u32 = 1; -pub const __STDC_ISO_10646__: u32 = 201505; -pub const __STDC_NO_THREADS__: u32 = 1; -pub const __GNU_LIBRARY__: u32 = 6; -pub const __GLIBC__: u32 = 2; -pub const __GLIBC_MINOR__: u32 = 23; -pub const _SYS_CDEFS_H: u32 = 1; -pub const __WORDSIZE: u32 = 64; -pub const __WORDSIZE_TIME64_COMPAT32: u32 = 1; -pub const __SYSCALL_WORDSIZE: u32 = 64; -pub const _BITS_WCHAR_H: u32 = 1; -pub const INT8_MIN: i32 = -128; -pub const INT16_MIN: i32 = -32768; -pub const INT32_MIN: i32 = -2147483648; -pub const INT8_MAX: u32 = 127; -pub const INT16_MAX: u32 = 32767; -pub const INT32_MAX: u32 = 2147483647; -pub const UINT8_MAX: u32 = 255; -pub const UINT16_MAX: u32 = 65535; -pub const UINT32_MAX: u32 = 4294967295; -pub const INT_LEAST8_MIN: i32 = -128; -pub const INT_LEAST16_MIN: i32 = -32768; -pub const INT_LEAST32_MIN: i32 = -2147483648; -pub const INT_LEAST8_MAX: u32 = 127; -pub const INT_LEAST16_MAX: u32 = 32767; -pub const INT_LEAST32_MAX: u32 = 2147483647; -pub const UINT_LEAST8_MAX: u32 = 255; -pub const UINT_LEAST16_MAX: u32 = 65535; -pub const UINT_LEAST32_MAX: u32 = 4294967295; -pub const INT_FAST8_MIN: i32 = -128; -pub const INT_FAST16_MIN: i64 = -9223372036854775808; -pub const INT_FAST32_MIN: i64 = -9223372036854775808; -pub const INT_FAST8_MAX: u32 = 127; -pub const INT_FAST16_MAX: u64 = 9223372036854775807; -pub const INT_FAST32_MAX: u64 = 9223372036854775807; -pub const UINT_FAST8_MAX: u32 = 255; -pub const UINT_FAST16_MAX: i32 = -1; -pub const UINT_FAST32_MAX: i32 = -1; -pub const INTPTR_MIN: i64 = -9223372036854775808; -pub const INTPTR_MAX: u64 = 9223372036854775807; -pub const UINTPTR_MAX: i32 = -1; -pub const PTRDIFF_MIN: i64 = -9223372036854775808; -pub const PTRDIFF_MAX: u64 = 9223372036854775807; -pub const SIG_ATOMIC_MIN: i32 = -2147483648; -pub const SIG_ATOMIC_MAX: u32 = 2147483647; -pub const SIZE_MAX: i32 = -1; -pub const WINT_MIN: u32 = 0; -pub const WINT_MAX: u32 = 4294967295; -pub type int_least8_t = ::std::os::raw::c_schar; -pub type int_least16_t = ::std::os::raw::c_short; -pub type int_least32_t = ::std::os::raw::c_int; -pub type int_least64_t = ::std::os::raw::c_long; -pub type uint_least8_t = ::std::os::raw::c_uchar; -pub type uint_least16_t = ::std::os::raw::c_ushort; -pub type uint_least32_t = ::std::os::raw::c_uint; -pub type uint_least64_t = ::std::os::raw::c_ulong; -pub type int_fast8_t = ::std::os::raw::c_schar; -pub type int_fast16_t = ::std::os::raw::c_long; -pub type int_fast32_t = ::std::os::raw::c_long; -pub type int_fast64_t = ::std::os::raw::c_long; -pub type uint_fast8_t = ::std::os::raw::c_uchar; -pub type uint_fast16_t = ::std::os::raw::c_ulong; -pub type uint_fast32_t = ::std::os::raw::c_ulong; -pub type uint_fast64_t = ::std::os::raw::c_ulong; -pub type intmax_t = ::std::os::raw::c_long; -pub type uintmax_t = ::std::os::raw::c_ulong; -pub type wchar_t = ::std::os::raw::c_int; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct max_align_t { - pub __clang_max_align_nonce1: ::std::os::raw::c_longlong, - pub __bindgen_padding_0: u64, - pub __clang_max_align_nonce2: f64, -} -pub const DLDeviceType_kDLCPU: DLDeviceType = 1; -pub const DLDeviceType_kDLGPU: DLDeviceType = 2; -pub const DLDeviceType_kDLCPUPinned: DLDeviceType = 3; -pub const DLDeviceType_kDLOpenCL: DLDeviceType = 4; -pub const DLDeviceType_kDLMetal: DLDeviceType = 8; -pub const DLDeviceType_kDLVPI: DLDeviceType = 9; -pub const DLDeviceType_kDLROCM: DLDeviceType = 10; -/// \brief The device type in DLContext. -pub type DLDeviceType = u32; -/// \brief A Device context for Tensor and operator. -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct DLContext { - /// \brief The device type used in the device. - pub device_type: DLDeviceType, - /// \brief The device index - pub device_id: ::std::os::raw::c_int, -} -pub const DLDataTypeCode_kDLInt: DLDataTypeCode = 0; -pub const DLDataTypeCode_kDLUInt: DLDataTypeCode = 1; -pub const DLDataTypeCode_kDLFloat: DLDataTypeCode = 2; -/// \brief The type code options DLDataType. -pub type DLDataTypeCode = u32; -/// \brief The data type the tensor can hold. +pub const TVM_VERSION : & 'static [ u8 ; 8usize ] = b"0.5.dev\0" ; pub const DLPACK_VERSION : u32 = 16 ; pub const _STDINT_H : u32 = 1 ; pub const _FEATURES_H : u32 = 1 ; pub const _DEFAULT_SOURCE : u32 = 1 ; pub const __USE_ISOC11 : u32 = 1 ; pub const __USE_ISOC99 : u32 = 1 ; pub const __USE_ISOC95 : u32 = 1 ; pub const __USE_POSIX_IMPLICITLY : u32 = 1 ; pub const _POSIX_SOURCE : u32 = 1 ; pub const _POSIX_C_SOURCE : u32 = 200809 ; pub const __USE_POSIX : u32 = 1 ; pub const __USE_POSIX2 : u32 = 1 ; pub const __USE_POSIX199309 : u32 = 1 ; pub const __USE_POSIX199506 : u32 = 1 ; pub const __USE_XOPEN2K : u32 = 1 ; pub const __USE_XOPEN2K8 : u32 = 1 ; pub const _ATFILE_SOURCE : u32 = 1 ; pub const __USE_MISC : u32 = 1 ; pub const __USE_ATFILE : u32 = 1 ; pub const __USE_FORTIFY_LEVEL : u32 = 0 ; pub const __GLIBC_USE_DEPRECATED_GETS : u32 = 0 ; pub const _STDC_PREDEF_H : u32 = 1 ; pub const __STDC_IEC_559__ : u32 = 1 ; pub const __STDC_IEC_559_COMPLEX__ : u32 = 1 ; pub const __STDC_ISO_10646__ : u32 = 201706 ; pub const __STDC_NO_THREADS__ : u32 = 1 ; pub const __GNU_LIBRARY__ : u32 = 6 ; pub const __GLIBC__ : u32 = 2 ; pub const __GLIBC_MINOR__ : u32 = 27 ; pub const _SYS_CDEFS_H : u32 = 1 ; pub const __glibc_c99_flexarr_available : u32 = 1 ; pub const __WORDSIZE : u32 = 64 ; pub const __WORDSIZE_TIME64_COMPAT32 : u32 = 1 ; pub const __SYSCALL_WORDSIZE : u32 = 64 ; pub const __HAVE_GENERIC_SELECTION : u32 = 1 ; pub const __GLIBC_USE_LIB_EXT2 : u32 = 0 ; pub const __GLIBC_USE_IEC_60559_BFP_EXT : u32 = 0 ; pub const __GLIBC_USE_IEC_60559_FUNCS_EXT : u32 = 0 ; pub const __GLIBC_USE_IEC_60559_TYPES_EXT : u32 = 0 ; pub const _BITS_TYPES_H : u32 = 1 ; pub const _BITS_TYPESIZES_H : u32 = 1 ; pub const __OFF_T_MATCHES_OFF64_T : u32 = 1 ; pub const __INO_T_MATCHES_INO64_T : u32 = 1 ; pub const __RLIM_T_MATCHES_RLIM64_T : u32 = 1 ; pub const __FD_SETSIZE : u32 = 1024 ; pub const _BITS_WCHAR_H : u32 = 1 ; pub const _BITS_STDINT_INTN_H : u32 = 1 ; pub const _BITS_STDINT_UINTN_H : u32 = 1 ; pub const INT8_MIN : i32 = -128 ; pub const INT16_MIN : i32 = -32768 ; pub const INT32_MIN : i32 = -2147483648 ; pub const INT8_MAX : u32 = 127 ; pub const INT16_MAX : u32 = 32767 ; pub const INT32_MAX : u32 = 2147483647 ; pub const UINT8_MAX : u32 = 255 ; pub const UINT16_MAX : u32 = 65535 ; pub const UINT32_MAX : u32 = 4294967295 ; pub const INT_LEAST8_MIN : i32 = -128 ; pub const INT_LEAST16_MIN : i32 = -32768 ; pub const INT_LEAST32_MIN : i32 = -2147483648 ; pub const INT_LEAST8_MAX : u32 = 127 ; pub const INT_LEAST16_MAX : u32 = 32767 ; pub const INT_LEAST32_MAX : u32 = 2147483647 ; pub const UINT_LEAST8_MAX : u32 = 255 ; pub const UINT_LEAST16_MAX : u32 = 65535 ; pub const UINT_LEAST32_MAX : u32 = 4294967295 ; pub const INT_FAST8_MIN : i32 = -128 ; pub const INT_FAST16_MIN : i64 = -9223372036854775808 ; pub const INT_FAST32_MIN : i64 = -9223372036854775808 ; pub const INT_FAST8_MAX : u32 = 127 ; pub const INT_FAST16_MAX : u64 = 9223372036854775807 ; pub const INT_FAST32_MAX : u64 = 9223372036854775807 ; pub const UINT_FAST8_MAX : u32 = 255 ; pub const UINT_FAST16_MAX : i32 = -1 ; pub const UINT_FAST32_MAX : i32 = -1 ; pub const INTPTR_MIN : i64 = -9223372036854775808 ; pub const INTPTR_MAX : u64 = 9223372036854775807 ; pub const UINTPTR_MAX : i32 = -1 ; pub const PTRDIFF_MIN : i64 = -9223372036854775808 ; pub const PTRDIFF_MAX : u64 = 9223372036854775807 ; pub const SIG_ATOMIC_MIN : i32 = -2147483648 ; pub const SIG_ATOMIC_MAX : u32 = 2147483647 ; pub const SIZE_MAX : i32 = -1 ; pub const WINT_MIN : u32 = 0 ; pub const WINT_MAX : u32 = 4294967295 ; pub type __u_char = :: std :: os :: raw :: c_uchar ; pub type __u_short = :: std :: os :: raw :: c_ushort ; pub type __u_int = :: std :: os :: raw :: c_uint ; pub type __u_long = :: std :: os :: raw :: c_ulong ; pub type __int8_t = :: std :: os :: raw :: c_schar ; pub type __uint8_t = :: std :: os :: raw :: c_uchar ; pub type __int16_t = :: std :: os :: raw :: c_short ; pub type __uint16_t = :: std :: os :: raw :: c_ushort ; pub type __int32_t = :: std :: os :: raw :: c_int ; pub type __uint32_t = :: std :: os :: raw :: c_uint ; pub type __int64_t = :: std :: os :: raw :: c_long ; pub type __uint64_t = :: std :: os :: raw :: c_ulong ; pub type __quad_t = :: std :: os :: raw :: c_long ; pub type __u_quad_t = :: std :: os :: raw :: c_ulong ; pub type __intmax_t = :: std :: os :: raw :: c_long ; pub type __uintmax_t = :: std :: os :: raw :: c_ulong ; pub type __dev_t = :: std :: os :: raw :: c_ulong ; pub type __uid_t = :: std :: os :: raw :: c_uint ; pub type __gid_t = :: std :: os :: raw :: c_uint ; pub type __ino_t = :: std :: os :: raw :: c_ulong ; pub type __ino64_t = :: std :: os :: raw :: c_ulong ; pub type __mode_t = :: std :: os :: raw :: c_uint ; pub type __nlink_t = :: std :: os :: raw :: c_ulong ; pub type __off_t = :: std :: os :: raw :: c_long ; pub type __off64_t = :: std :: os :: raw :: c_long ; pub type __pid_t = :: std :: os :: raw :: c_int ; # [ repr ( C ) ] # [ derive ( Debug , Copy , Clone , PartialEq , Eq ) ] pub struct __fsid_t { pub __val : [ :: std :: os :: raw :: c_int ; 2usize ] , } pub type __clock_t = :: std :: os :: raw :: c_long ; pub type __rlim_t = :: std :: os :: raw :: c_ulong ; pub type __rlim64_t = :: std :: os :: raw :: c_ulong ; pub type __id_t = :: std :: os :: raw :: c_uint ; pub type __time_t = :: std :: os :: raw :: c_long ; pub type __useconds_t = :: std :: os :: raw :: c_uint ; pub type __suseconds_t = :: std :: os :: raw :: c_long ; pub type __daddr_t = :: std :: os :: raw :: c_int ; pub type __key_t = :: std :: os :: raw :: c_int ; pub type __clockid_t = :: std :: os :: raw :: c_int ; pub type __timer_t = * mut :: std :: os :: raw :: c_void ; pub type __blksize_t = :: std :: os :: raw :: c_long ; pub type __blkcnt_t = :: std :: os :: raw :: c_long ; pub type __blkcnt64_t = :: std :: os :: raw :: c_long ; pub type __fsblkcnt_t = :: std :: os :: raw :: c_ulong ; pub type __fsblkcnt64_t = :: std :: os :: raw :: c_ulong ; pub type __fsfilcnt_t = :: std :: os :: raw :: c_ulong ; pub type __fsfilcnt64_t = :: std :: os :: raw :: c_ulong ; pub type __fsword_t = :: std :: os :: raw :: c_long ; pub type __ssize_t = :: std :: os :: raw :: c_long ; pub type __syscall_slong_t = :: std :: os :: raw :: c_long ; pub type __syscall_ulong_t = :: std :: os :: raw :: c_ulong ; pub type __loff_t = __off64_t ; pub type __caddr_t = * mut :: std :: os :: raw :: c_char ; pub type __intptr_t = :: std :: os :: raw :: c_long ; pub type __socklen_t = :: std :: os :: raw :: c_uint ; pub type __sig_atomic_t = :: std :: os :: raw :: c_int ; pub type int_least8_t = :: std :: os :: raw :: c_schar ; pub type int_least16_t = :: std :: os :: raw :: c_short ; pub type int_least32_t = :: std :: os :: raw :: c_int ; pub type int_least64_t = :: std :: os :: raw :: c_long ; pub type uint_least8_t = :: std :: os :: raw :: c_uchar ; pub type uint_least16_t = :: std :: os :: raw :: c_ushort ; pub type uint_least32_t = :: std :: os :: raw :: c_uint ; pub type uint_least64_t = :: std :: os :: raw :: c_ulong ; pub type int_fast8_t = :: std :: os :: raw :: c_schar ; pub type int_fast16_t = :: std :: os :: raw :: c_long ; pub type int_fast32_t = :: std :: os :: raw :: c_long ; pub type int_fast64_t = :: std :: os :: raw :: c_long ; pub type uint_fast8_t = :: std :: os :: raw :: c_uchar ; pub type uint_fast16_t = :: std :: os :: raw :: c_ulong ; pub type uint_fast32_t = :: std :: os :: raw :: c_ulong ; pub type uint_fast64_t = :: std :: os :: raw :: c_ulong ; pub type intmax_t = __intmax_t ; pub type uintmax_t = __uintmax_t ; pub type wchar_t = :: std :: os :: raw :: c_int ; + /// \brief CPU device + pub const DLDeviceType_kDLCPU : DLDeviceType = 1 ; + /// \brief CUDA GPU device + pub const DLDeviceType_kDLGPU : DLDeviceType = 2 ; + /// \brief Pinned CUDA GPU device by cudaMallocHost +/// \note kDLCPUPinned = kDLCPU | kDLGPU + pub const DLDeviceType_kDLCPUPinned : DLDeviceType = 3 ; + /// \brief OpenCL devices. + pub const DLDeviceType_kDLOpenCL : DLDeviceType = 4 ; + /// \brief Vulkan buffer for next generation graphics. + pub const DLDeviceType_kDLVulkan : DLDeviceType = 7 ; + /// \brief Metal for Apple GPU. + pub const DLDeviceType_kDLMetal : DLDeviceType = 8 ; + /// \brief Verilog simulator buffer + pub const DLDeviceType_kDLVPI : DLDeviceType = 9 ; + /// \brief ROCm GPUs for AMD GPUs + pub const DLDeviceType_kDLROCM : DLDeviceType = 10 ; + /// \brief Reserved extension device type, +/// used for quickly test extension device +/// The semantics can differ depending on the implementation. + pub const DLDeviceType_kDLExtDev : DLDeviceType = 12 ; + /// \brief The device type in DLContext. + pub type DLDeviceType = u32 ; + /// \brief A Device context for Tensor and operator. + # [ repr ( C ) ] # [ derive ( Debug , Copy , Clone , PartialEq , Eq ) ] pub struct DLContext { + /// \brief The device type used in the device. + pub device_type : DLDeviceType , + /// \brief The device index + pub device_id : :: std :: os :: raw :: c_int , } pub const DLDataTypeCode_kDLInt : DLDataTypeCode = 0 ; pub const DLDataTypeCode_kDLUInt : DLDataTypeCode = 1 ; pub const DLDataTypeCode_kDLFloat : DLDataTypeCode = 2 ; + /// \brief The type code options DLDataType. + pub type DLDataTypeCode = u32 ; + /// \brief The data type the tensor can hold. /// /// Examples /// - float: type_code = 2, bits = 32, lanes=1 /// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 -/// - int8: type_code = 0, bits = 8, lanes=1 -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct DLDataType { - /// \brief Type code of base types. - /// We keep it uint8_t instead of DLDataTypeCode for minimal memory - /// footprint, but the value should be one of DLDataTypeCode enum values. - /// - pub code: u8, - /// \brief Number of bits, common choices are 8, 16, 32. - pub bits: u8, - /// \brief Number of lanes in the type, used for vector types. - pub lanes: u16, -} -/// \brief Plain C Tensor object, does not manage memory. -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct DLTensor { - /// \brief The opaque data pointer points to the allocated data. - /// This will be CUDA device pointer or cl_mem handle in OpenCL. - /// This pointer is always aligns to 256 bytes as in CUDA. - pub data: *mut ::std::os::raw::c_void, - /// \brief The device context of the tensor - pub ctx: DLContext, - /// \brief Number of dimensions - pub ndim: ::std::os::raw::c_int, - /// \brief The data type of the pointer - pub dtype: DLDataType, - /// \brief The shape of the tensor - pub shape: *mut i64, - /// \brief strides of the tensor, - /// can be NULL, indicating tensor is compact. - pub strides: *mut i64, - /// \brief The offset in bytes to the beginning pointer to data - pub byte_offset: u64, -} -/// \brief C Tensor object, manage memory of DLTensor. This data structure is +/// - int8: type_code = 0, bits = 8, lanes=1 + # [ repr ( C ) ] # [ derive ( Debug , Copy , Clone , PartialEq , Eq ) ] pub struct DLDataType { + /// \brief Type code of base types. + /// We keep it uint8_t instead of DLDataTypeCode for minimal memory + /// footprint, but the value should be one of DLDataTypeCode enum values. + /// + pub code : u8 , + /// \brief Number of bits, common choices are 8, 16, 32. + pub bits : u8 , + /// \brief Number of lanes in the type, used for vector types. + pub lanes : u16 , } + /// \brief Plain C Tensor object, does not manage memory. + # [ repr ( C ) ] # [ derive ( Debug , Copy , Clone , PartialEq , Eq ) ] pub struct DLTensor { + /// \brief The opaque data pointer points to the allocated data. + /// This will be CUDA device pointer or cl_mem handle in OpenCL. + /// This pointer is always aligns to 256 bytes as in CUDA. + pub data : * mut :: std :: os :: raw :: c_void , + /// \brief The device context of the tensor + pub ctx : DLContext , + /// \brief Number of dimensions + pub ndim : :: std :: os :: raw :: c_int , + /// \brief The data type of the pointer + pub dtype : DLDataType , + /// \brief The shape of the tensor + pub shape : * mut i64 , + /// \brief strides of the tensor, + /// can be NULL, indicating tensor is compact. + pub strides : * mut i64 , + /// \brief The offset in bytes to the beginning pointer to data + pub byte_offset : u64 , } + /// \brief C Tensor object, manage memory of DLTensor. This data structure is /// intended to faciliate the borrowing of DLTensor by another framework. It is /// not meant to transfer the tensor. When the borrowing framework doesn't need /// the tensor, it should call the deleter to notify the host that the resource -/// is no longer needed. -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct DLManagedTensor { - /// \brief DLTensor which is being memory managed - pub dl_tensor: DLTensor, - /// \brief the context of the original host framework of DLManagedTensor in - /// which DLManagedTensor is used in the framework. It can also be NULL. - pub manager_ctx: *mut ::std::os::raw::c_void, - /// \brief Destructor signature void (*)(void*) - this should be called - /// to destruct manager_ctx which holds the DLManagedTensor. It can be NULL - /// if there is no way for the caller to provide a reasonable destructor. - pub deleter: ::std::option::Option, -} -/// \brief type of array index. -pub type tvm_index_t = i64; -pub const TVMDeviceExtType_kDLAOCL: TVMDeviceExtType = 5; -pub const TVMDeviceExtType_kDLSDAccel: TVMDeviceExtType = 6; -pub const TVMDeviceExtType_kDLVulkan: TVMDeviceExtType = 7; -pub const TVMDeviceExtType_kOpenGL: TVMDeviceExtType = 11; -pub const TVMDeviceExtType_kExtDev: TVMDeviceExtType = 12; -/// \brief Extension device types in TVM -pub type TVMDeviceExtType = u32; -pub const TVMTypeCode_kHandle: TVMTypeCode = 3; -pub const TVMTypeCode_kNull: TVMTypeCode = 4; -pub const TVMTypeCode_kTVMType: TVMTypeCode = 5; -pub const TVMTypeCode_kTVMContext: TVMTypeCode = 6; -pub const TVMTypeCode_kArrayHandle: TVMTypeCode = 7; -pub const TVMTypeCode_kNodeHandle: TVMTypeCode = 8; -pub const TVMTypeCode_kModuleHandle: TVMTypeCode = 9; -pub const TVMTypeCode_kFuncHandle: TVMTypeCode = 10; -pub const TVMTypeCode_kStr: TVMTypeCode = 11; -pub const TVMTypeCode_kBytes: TVMTypeCode = 12; -pub const TVMTypeCode_kNDArrayContainer: TVMTypeCode = 13; -pub const TVMTypeCode_kExtBegin: TVMTypeCode = 15; -pub const TVMTypeCode_kNNVMFirst: TVMTypeCode = 16; -pub const TVMTypeCode_kNNVMLast: TVMTypeCode = 20; -pub const TVMTypeCode_kExtReserveEnd: TVMTypeCode = 64; -pub const TVMTypeCode_kExtEnd: TVMTypeCode = 128; -/// \brief The type code in TVMType -/// \note TVMType is used in two places. -pub type TVMTypeCode = u32; -/// \brief The data type used in TVM Runtime. +/// is no longer needed. + # [ repr ( C ) ] # [ derive ( Debug , Copy , Clone , PartialEq , Eq ) ] pub struct DLManagedTensor { + /// \brief DLTensor which is being memory managed + pub dl_tensor : DLTensor , + /// \brief the context of the original host framework of DLManagedTensor in + /// which DLManagedTensor is used in the framework. It can also be NULL. + pub manager_ctx : * mut :: std :: os :: raw :: c_void , + /// \brief Destructor signature void (*)(void*) - this should be called + /// to destruct manager_ctx which holds the DLManagedTensor. It can be NULL + /// if there is no way for the caller to provide a reasonable destructor. + pub deleter : :: std :: option :: Option < unsafe extern "C" fn ( self_ : * mut DLManagedTensor ) > , } + /// \brief type of array index. + pub type tvm_index_t = i64 ; pub const TVMDeviceExtType_kDLAOCL : TVMDeviceExtType = 5 ; pub const TVMDeviceExtType_kDLSDAccel : TVMDeviceExtType = 6 ; pub const TVMDeviceExtType_kOpenGL : TVMDeviceExtType = 11 ; + /// \brief Extension device types in TVM + pub type TVMDeviceExtType = u32 ; pub const TVMTypeCode_kHandle : TVMTypeCode = 3 ; pub const TVMTypeCode_kNull : TVMTypeCode = 4 ; pub const TVMTypeCode_kTVMType : TVMTypeCode = 5 ; pub const TVMTypeCode_kTVMContext : TVMTypeCode = 6 ; pub const TVMTypeCode_kArrayHandle : TVMTypeCode = 7 ; pub const TVMTypeCode_kNodeHandle : TVMTypeCode = 8 ; pub const TVMTypeCode_kModuleHandle : TVMTypeCode = 9 ; pub const TVMTypeCode_kFuncHandle : TVMTypeCode = 10 ; pub const TVMTypeCode_kStr : TVMTypeCode = 11 ; pub const TVMTypeCode_kBytes : TVMTypeCode = 12 ; pub const TVMTypeCode_kNDArrayContainer : TVMTypeCode = 13 ; pub const TVMTypeCode_kExtBegin : TVMTypeCode = 15 ; pub const TVMTypeCode_kNNVMFirst : TVMTypeCode = 16 ; pub const TVMTypeCode_kNNVMLast : TVMTypeCode = 20 ; pub const TVMTypeCode_kExtReserveEnd : TVMTypeCode = 64 ; pub const TVMTypeCode_kExtEnd : TVMTypeCode = 128 ; + /// \brief The type code in TVMType +/// \note TVMType is used in two places. + pub type TVMTypeCode = u32 ; + /// \brief The data type used in TVM Runtime. /// /// Examples /// - float: type_code = 2, bits = 32, lanes=1 /// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 /// - int8: type_code = 0, bits = 8, lanes=1 /// -/// \note Arguments TVM API function always takes bits=64 and lanes=1 -pub type TVMType = DLDataType; -/// \brief The Device information, abstract away common device types. -pub type TVMContext = DLContext; -/// \brief The tensor array stucture to TVM API. -pub type TVMArray = DLTensor; -/// \brief the array handle -pub type TVMArrayHandle = *mut TVMArray; -/// \brief Union type of values -/// being passed through API and function calls. -#[repr(C)] -#[derive(Copy, Clone)] -pub union TVMValue { - pub v_int64: i64, - pub v_float64: f64, - pub v_handle: *mut ::std::os::raw::c_void, - pub v_str: *const ::std::os::raw::c_char, - pub v_type: TVMType, - pub v_ctx: TVMContext, - _bindgen_union_align: u64, -} -/// \brief Byte array type used to pass in byte array -/// When kBytes is used as data type. -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct TVMByteArray { - pub data: *const ::std::os::raw::c_char, - pub size: usize, -} -/// \brief Handle to TVM runtime modules. -pub type TVMModuleHandle = *mut ::std::os::raw::c_void; -/// \brief Handle to packed function handle. -pub type TVMFunctionHandle = *mut ::std::os::raw::c_void; -/// \brief Handle to hold return value. -pub type TVMRetValueHandle = *mut ::std::os::raw::c_void; -/// \brief The stream that is specific to device -/// can be NULL, which indicates the default one. -pub type TVMStreamHandle = *mut ::std::os::raw::c_void; -extern "C" { - /// \brief Used for implementing C API function. - /// Set last error message before return. - /// \param msg The error message to be set. - pub fn TVMAPISetLastError(msg: *const ::std::os::raw::c_char); -} -extern "C" { - /// \brief return str message of the last error - /// all function in this file will return 0 when success - /// and -1 when an error occured, - /// TVMGetLastError can be called to retrieve the error - /// - /// this function is threadsafe and can be called by different thread - /// \return error info - pub fn TVMGetLastError() -> *const ::std::os::raw::c_char; -} -extern "C" { - /// \brief Load module from file. - /// \param file_name The file name to load the module from. - /// \param format The format of the module. - /// \param out The result module - /// - /// \return 0 when success, -1 when failure happens - /// \note The resulting module do not contain import relation. - /// It can be reconstructed by TVMModImport. - pub fn TVMModLoadFromFile( - file_name: *const ::std::os::raw::c_char, - format: *const ::std::os::raw::c_char, - out: *mut TVMModuleHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Add dep to mod's dependency. - /// This allows functions in this module to use modules. - /// - /// \param mod The module handle. - /// \param dep The dependent module to be imported. - /// \return 0 when success, -1 when failure happens - pub fn TVMModImport(mod_: TVMModuleHandle, dep: TVMModuleHandle) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Get function from the module. - /// \param mod The module handle. - /// \param func_name The name of the function. - /// \param query_imports Whether to query imported modules - /// \param out The result function, can be NULL if it is not available. - /// \return 0 when no error is thrown, -1 when failure happens - pub fn TVMModGetFunction( - mod_: TVMModuleHandle, - func_name: *const ::std::os::raw::c_char, - query_imports: ::std::os::raw::c_int, - out: *mut TVMFunctionHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Free front-end extension type resource. - /// \param handle The extension handle. - /// \param type_code The type of of the extension type. - /// \return 0 when success, -1 when failure happens - pub fn TVMExtTypeFree( - handle: *mut ::std::os::raw::c_void, - type_code: ::std::os::raw::c_int, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Free the Module - /// \param mod The module to be freed. - /// - /// \note This may not free up the module's resources. - /// If there is active TVMFunctionHandle uses the module - /// Or if this module is imported by another active module. - /// - /// The all functions remains valid until TVMFuncFree is called. - /// \return 0 when success, -1 when failure happens - pub fn TVMModFree(mod_: TVMModuleHandle) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Free the function when it is no longer needed. - /// \param func The function handle - /// \return 0 when success, -1 when failure happens - pub fn TVMFuncFree(func: TVMFunctionHandle) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Call a Packed TVM Function. - /// - /// \param func node handle of the function. - /// \param arg_values The arguments - /// \param type_codes The type codes of the arguments - /// \param num_args Number of arguments. - /// - /// \param ret_val The return value. - /// \param ret_type_code the type code of return value. - /// - /// \return 0 when success, -1 when failure happens - /// \note TVM calls always exchanges with type bits=64, lanes=1 - /// - /// \note API calls always exchanges with type bits=64, lanes=1 - /// If API call returns container handles (e.g. FunctionHandle) - /// these handles should be managed by the front-end. - /// The front-end need to call free function (e.g. TVMFuncFree) - /// to free these handles. - pub fn TVMFuncCall( - func: TVMFunctionHandle, - arg_values: *mut TVMValue, - type_codes: *mut ::std::os::raw::c_int, - num_args: ::std::os::raw::c_int, - ret_val: *mut TVMValue, - ret_type_code: *mut ::std::os::raw::c_int, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Set the return value of TVMPackedCFunc. - /// - /// This function is called by TVMPackedCFunc to set the return value. - /// When this function is not called, the function returns null by default. - /// - /// \param ret The return value handle, pass by ret in TVMPackedCFunc - /// \param value The value to be returned. - /// \param type_code The type of the value to be returned. - /// \param num_ret Number of return values, for now only 1 is supported. - pub fn TVMCFuncSetReturn( - ret: TVMRetValueHandle, - value: *mut TVMValue, - type_code: *mut ::std::os::raw::c_int, - num_ret: ::std::os::raw::c_int, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Inplace translate callback argument value to return value. - /// This is only needed for non-POD arguments. - /// - /// \param value The value to be translated. - /// \param code The type code to be translated. - /// \note This function will do a shallow copy when necessary. - /// - /// \return 0 when success, -1 when failure happens. - pub fn TVMCbArgToReturn( - value: *mut TVMValue, - code: ::std::os::raw::c_int, - ) -> ::std::os::raw::c_int; -} -/// \brief C type of packed function. +/// \note Arguments TVM API function always takes bits=64 and lanes=1 + pub type TVMType = DLDataType ; + /// \brief The Device information, abstract away common device types. + pub type TVMContext = DLContext ; + /// \brief The tensor array stucture to TVM API. + pub type TVMArray = DLTensor ; + /// \brief the array handle + pub type TVMArrayHandle = * mut TVMArray ; + /// \brief Union type of values +/// being passed through API and function calls. + # [ repr ( C ) ] # [ derive ( Copy , Clone ) ] pub union TVMValue { pub v_int64 : i64 , pub v_float64 : f64 , pub v_handle : * mut :: std :: os :: raw :: c_void , pub v_str : * const :: std :: os :: raw :: c_char , pub v_type : TVMType , pub v_ctx : TVMContext , _bindgen_union_align : u64 , } + /// \brief Byte array type used to pass in byte array +/// When kBytes is used as data type. + # [ repr ( C ) ] # [ derive ( Debug , Copy , Clone , PartialEq , Eq ) ] pub struct TVMByteArray { pub data : * const :: std :: os :: raw :: c_char , pub size : usize , } + /// \brief Handle to TVM runtime modules. + pub type TVMModuleHandle = * mut :: std :: os :: raw :: c_void ; + /// \brief Handle to packed function handle. + pub type TVMFunctionHandle = * mut :: std :: os :: raw :: c_void ; + /// \brief Handle to hold return value. + pub type TVMRetValueHandle = * mut :: std :: os :: raw :: c_void ; + /// \brief The stream that is specific to device +/// can be NULL, which indicates the default one. + pub type TVMStreamHandle = * mut :: std :: os :: raw :: c_void ; extern "C" { + /// \brief Used for implementing C API function. +/// Set last error message before return. +/// \param msg The error message to be set. + pub fn TVMAPISetLastError ( msg : * const :: std :: os :: raw :: c_char ) ; } extern "C" { + /// \brief return str message of the last error +/// all function in this file will return 0 when success +/// and -1 when an error occured, +/// TVMGetLastError can be called to retrieve the error +/// +/// this function is threadsafe and can be called by different thread +/// \return error info + pub fn TVMGetLastError ( ) -> * const :: std :: os :: raw :: c_char ; } extern "C" { + /// \brief Load module from file. +/// \param file_name The file name to load the module from. +/// \param format The format of the module. +/// \param out The result module +/// +/// \return 0 when success, -1 when failure happens +/// \note The resulting module do not contain import relation. +/// It can be reconstructed by TVMModImport. + pub fn TVMModLoadFromFile ( file_name : * const :: std :: os :: raw :: c_char , format : * const :: std :: os :: raw :: c_char , out : * mut TVMModuleHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { + /// \brief Add dep to mod's dependency. +/// This allows functions in this module to use modules. +/// +/// \param mod The module handle. +/// \param dep The dependent module to be imported. +/// \return 0 when success, -1 when failure happens + pub fn TVMModImport ( mod_ : TVMModuleHandle , dep : TVMModuleHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { + /// \brief Get function from the module. +/// \param mod The module handle. +/// \param func_name The name of the function. +/// \param query_imports Whether to query imported modules +/// \param out The result function, can be NULL if it is not available. +/// \return 0 when no error is thrown, -1 when failure happens + pub fn TVMModGetFunction ( mod_ : TVMModuleHandle , func_name : * const :: std :: os :: raw :: c_char , query_imports : :: std :: os :: raw :: c_int , out : * mut TVMFunctionHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { + /// \brief Free front-end extension type resource. +/// \param handle The extension handle. +/// \param type_code The type of of the extension type. +/// \return 0 when success, -1 when failure happens + pub fn TVMExtTypeFree ( handle : * mut :: std :: os :: raw :: c_void , type_code : :: std :: os :: raw :: c_int ) -> :: std :: os :: raw :: c_int ; } extern "C" { + /// \brief Free the Module +/// \param mod The module to be freed. +/// +/// \note This may not free up the module's resources. +/// If there is active TVMFunctionHandle uses the module +/// Or if this module is imported by another active module. +/// +/// The all functions remains valid until TVMFuncFree is called. +/// \return 0 when success, -1 when failure happens + pub fn TVMModFree ( mod_ : TVMModuleHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { + /// \brief Free the function when it is no longer needed. +/// \param func The function handle +/// \return 0 when success, -1 when failure happens + pub fn TVMFuncFree ( func : TVMFunctionHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { + /// \brief Call a Packed TVM Function. +/// +/// \param func node handle of the function. +/// \param arg_values The arguments +/// \param type_codes The type codes of the arguments +/// \param num_args Number of arguments. +/// +/// \param ret_val The return value. +/// \param ret_type_code the type code of return value. +/// +/// \return 0 when success, -1 when failure happens +/// \note TVM calls always exchanges with type bits=64, lanes=1 +/// +/// \note API calls always exchanges with type bits=64, lanes=1 +/// If API call returns container handles (e.g. FunctionHandle) +/// these handles should be managed by the front-end. +/// The front-end need to call free function (e.g. TVMFuncFree) +/// to free these handles. + pub fn TVMFuncCall ( func : TVMFunctionHandle , arg_values : * mut TVMValue , type_codes : * mut :: std :: os :: raw :: c_int , num_args : :: std :: os :: raw :: c_int , ret_val : * mut TVMValue , ret_type_code : * mut :: std :: os :: raw :: c_int ) -> :: std :: os :: raw :: c_int ; } extern "C" { + /// \brief Set the return value of TVMPackedCFunc. +/// +/// This function is called by TVMPackedCFunc to set the return value. +/// When this function is not called, the function returns null by default. +/// +/// \param ret The return value handle, pass by ret in TVMPackedCFunc +/// \param value The value to be returned. +/// \param type_code The type of the value to be returned. +/// \param num_ret Number of return values, for now only 1 is supported. + pub fn TVMCFuncSetReturn ( ret : TVMRetValueHandle , value : * mut TVMValue , type_code : * mut :: std :: os :: raw :: c_int , num_ret : :: std :: os :: raw :: c_int ) -> :: std :: os :: raw :: c_int ; } extern "C" { + /// \brief Inplace translate callback argument value to return value. +/// This is only needed for non-POD arguments. +/// +/// \param value The value to be translated. +/// \param code The type code to be translated. +/// \note This function will do a shallow copy when necessary. +/// +/// \return 0 when success, -1 when failure happens. + pub fn TVMCbArgToReturn ( value : * mut TVMValue , code : :: std :: os :: raw :: c_int ) -> :: std :: os :: raw :: c_int ; } + /// \brief C type of packed function. /// /// \param args The arguments /// \param type_codes The type codes of the arguments @@ -401,370 +219,136 @@ extern "C" { /// \param ret The return value handle. /// \param resource_handle The handle additional resouce handle from fron-end. /// \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. -/// \sa TVMCFuncSetReturn -pub type TVMPackedCFunc = ::std::option::Option< - unsafe extern "C" fn( - args: *mut TVMValue, - type_codes: *mut ::std::os::raw::c_int, - num_args: ::std::os::raw::c_int, - ret: TVMRetValueHandle, - resource_handle: *mut ::std::os::raw::c_void, - ) -> ::std::os::raw::c_int, ->; -/// \brief C callback to free the resource handle in C packed function. -/// \param resource_handle The handle additional resouce handle from fron-end. -pub type TVMPackedCFuncFinalizer = - ::std::option::Option; -/// \brief Signature for extension function declarer. +/// \sa TVMCFuncSetReturn + pub type TVMPackedCFunc = :: std :: option :: Option < unsafe extern "C" fn ( args : * mut TVMValue , type_codes : * mut :: std :: os :: raw :: c_int , num_args : :: std :: os :: raw :: c_int , ret : TVMRetValueHandle , resource_handle : * mut :: std :: os :: raw :: c_void ) -> :: std :: os :: raw :: c_int > ; + /// \brief C callback to free the resource handle in C packed function. +/// \param resource_handle The handle additional resouce handle from fron-end. + pub type TVMPackedCFuncFinalizer = :: std :: option :: Option < unsafe extern "C" fn ( resource_handle : * mut :: std :: os :: raw :: c_void ) > ; + /// \brief Signature for extension function declarer. /// /// TVM call this function to get the extension functions /// The declarer will call register_func to register function and their name. /// /// \param register_func_handle The register function -/// \return 0 if success, -1 if failure happens -pub type TVMExtensionFuncDeclarer = ::std::option::Option< - unsafe extern "C" fn(register_func_handle: TVMFunctionHandle) -> ::std::os::raw::c_int, ->; -extern "C" { - /// \brief Wrap a TVMPackedCFunc to become a FunctionHandle. - /// - /// The resource_handle will be managed by TVM API, until the function is no longer used. - /// - /// \param func The packed C function. - /// \param resource_handle The resource handle from front-end, can be NULL. - /// \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL - /// \param out the result function handle. - /// \return 0 when success, -1 when failure happens - pub fn TVMFuncCreateFromCFunc( - func: TVMPackedCFunc, - resource_handle: *mut ::std::os::raw::c_void, - fin: TVMPackedCFuncFinalizer, - out: *mut TVMFunctionHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Register the function to runtime's global table. - /// - /// The registered function then can be pulled by the backend by the name. - /// - /// \param name The name of the function. - /// \param f The function to be registered. - /// \param override Whether allow override already registered function. - pub fn TVMFuncRegisterGlobal( - name: *const ::std::os::raw::c_char, - f: TVMFunctionHandle, - override_: ::std::os::raw::c_int, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Get a global function. - /// - /// \param name The name of the function. - /// \param out the result function pointer, NULL if it does not exist. - /// - /// \note The function handle of global function is managed by TVM runtime, - /// So TVMFuncFree is should not be called when it get deleted. - pub fn TVMFuncGetGlobal( - name: *const ::std::os::raw::c_char, - out: *mut TVMFunctionHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief List all the globally registered function name - /// \param out_size The number of functions - /// \param out_array The array of function names. - /// \return 0 when success, -1 when failure happens - pub fn TVMFuncListGlobalNames( - out_size: *mut ::std::os::raw::c_int, - out_array: *mut *mut *const ::std::os::raw::c_char, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Allocate a nd-array's memory, - /// including space of shape, of given spec. - /// - /// \param shape The shape of the array, the data content will be copied to out - /// \param ndim The number of dimension of the array. - /// \param dtype_code The type code of the dtype - /// \param dtype_bits The number of bits of dtype - /// \param dtype_lanes The number of lanes in the dtype. - /// \param device_type The device type of context - /// \param device_id The device id of context. - /// \param out The output handle. - /// \return 0 when success, -1 when failure happens - pub fn TVMArrayAlloc( - shape: *const tvm_index_t, - ndim: ::std::os::raw::c_int, - dtype_code: ::std::os::raw::c_int, - dtype_bits: ::std::os::raw::c_int, - dtype_lanes: ::std::os::raw::c_int, - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - out: *mut TVMArrayHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Free the TVM Array. - /// \param handle The array handle to be freed. - /// \return 0 when success, -1 when failure happens - pub fn TVMArrayFree(handle: TVMArrayHandle) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Copy array data from CPU byte array. - /// \param handle The array handle. - /// \param data the data pointer - /// \param nbytes The number of bytes to copy. - /// \return 0 when success, -1 when failure happens - pub fn TVMArrayCopyFromBytes( - handle: TVMArrayHandle, - data: *mut ::std::os::raw::c_void, - nbytes: usize, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Copy array data to CPU byte array. - /// \param handle The array handle. - /// \param data the data pointer - /// \param nbytes The number of bytes to copy. - /// \return 0 when success, -1 when failure happens - pub fn TVMArrayCopyToBytes( - handle: TVMArrayHandle, - data: *mut ::std::os::raw::c_void, - nbytes: usize, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Copy the array, both from and to must be valid during the copy. - /// \param from The array to be copied from. - /// \param to The target space. - /// \param stream The stream where the copy happens, can be NULL. - /// \return 0 when success, -1 when failure happens - pub fn TVMArrayCopyFromTo( - from: TVMArrayHandle, - to: TVMArrayHandle, - stream: TVMStreamHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Produce an array from the DLManagedTensor that shares data memory - /// with the DLManagedTensor. - /// \param from The source DLManagedTensor. - /// \param out The output array handle. - /// \return 0 when success, -1 when failure happens - pub fn TVMArrayFromDLPack( - from: *mut DLManagedTensor, - out: *mut TVMArrayHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Produce a DLMangedTensor from the array that shares data memory with - /// the array. - /// \param from The source array. - /// \param out The DLManagedTensor handle. - /// \return 0 when success, -1 when failure happens - pub fn TVMArrayToDLPack( - from: TVMArrayHandle, - out: *mut *mut DLManagedTensor, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Delete (free) a DLManagedTensor's data. - /// \param dltensor Pointer to the DLManagedTensor. - pub fn TVMDLManagedTensorCallDeleter(dltensor: *mut DLManagedTensor); -} -extern "C" { - /// \brief Create a new runtime stream. - /// - /// \param device_type The device type of context - /// \param device_id The device id of context - /// \param out The new stream handle - /// \return 0 when success, -1 when failure happens - pub fn TVMStreamCreate( - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - out: *mut TVMStreamHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Free a created stream handle. - /// - /// \param device_type The device type of context - /// \param device_id The device id of context - /// \param stream The stream to be freed - /// \return 0 when success, -1 when failure happens - pub fn TVMStreamFree( - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - stream: TVMStreamHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Set the runtime stream of current thread to be stream. - /// The subsequent calls to the same device_type - /// will use the setted stream handle. - /// The specific type of stream is runtime device dependent. - /// - /// \param device_type The device type of context - /// \param device_id The device id of context. - /// \param handle The stream handle. - /// \return 0 when success, -1 when failure happens - pub fn TVMSetStream( - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - handle: TVMStreamHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Wait until all computations on stream completes. - /// - /// \param device_type The device type of context - /// \param device_id The device id of context. - /// \param stream The stream to be synchronized. - /// \return 0 when success, -1 when failure happens - pub fn TVMSynchronize( - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - stream: TVMStreamHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Synchronize two streams of execution. - /// - /// \param device_type The device type of context - /// \param device_id The device id of context - /// \param src The source stream to synchronize. - /// \param dst The destination stream to synchronize. - /// \return 0 when success, -1 when failure happens - pub fn TVMStreamStreamSynchronize( - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - src: TVMStreamHandle, - dst: TVMStreamHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Backend function for modules to get function - /// from its environment mod_node (its imports and global function). - /// The user do should not call TVMFuncFree on func. - /// - /// \param mod_node The module handle. - /// \param func_name The name of the function. - /// \param out The result function. - /// \return 0 when no error is thrown, -1 when failure happens - pub fn TVMBackendGetFuncFromEnv( - mod_node: *mut ::std::os::raw::c_void, - func_name: *const ::std::os::raw::c_char, - out: *mut TVMFunctionHandle, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Backend function to register system-wide library symbol. - /// - /// \param name The name of the symbol - /// \param ptr The symbol address. - /// \return 0 when no error is thrown, -1 when failure happens - pub fn TVMBackendRegisterSystemLibSymbol( - name: *const ::std::os::raw::c_char, - ptr: *mut ::std::os::raw::c_void, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Backend function to allocate temporal workspace. - /// - /// \note The result allocate spaced is ensured to be aligned to kTempAllocaAlignment. - /// - /// \param nbytes The size of the space requested. - /// \param device_type The device type which the space will be allocated. - /// \param device_id The device id which the space will be allocated. - /// \param dtype_code_hint The type code of the array elements. Only used in - /// certain backends such as OpenGL. - /// \param dtype_bits_hint The type bits of the array elements. Only used in - /// certain backends such as OpenGL. - /// \return nullptr when error is thrown, a valid ptr if success - pub fn TVMBackendAllocWorkspace( - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - nbytes: u64, - dtype_code_hint: ::std::os::raw::c_int, - dtype_bits_hint: ::std::os::raw::c_int, - ) -> *mut ::std::os::raw::c_void; -} -extern "C" { - /// \brief Backend function to free temporal workspace. - /// - /// \param ptr The result allocated space pointer. - /// \param device_type The device type which the space will be allocated. - /// \param device_id The device id which the space will be allocated. - /// \return 0 when no error is thrown, -1 when failure happens - /// - /// \sa TVMBackendAllocWorkspace - pub fn TVMBackendFreeWorkspace( - device_type: ::std::os::raw::c_int, - device_id: ::std::os::raw::c_int, - ptr: *mut ::std::os::raw::c_void, - ) -> ::std::os::raw::c_int; -} -/// \brief Environment for TVM parallel task. -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct TVMParallelGroupEnv { - /// \brief Auxiliary used for synchronization - pub sync_handle: *mut ::std::os::raw::c_void, - /// \brief total amount of task - pub num_task: i32, -} -/// \brief The callback function to execute a parallel lambda -/// \param task_id the task id of the function. -/// \param penv The parallel environment backs the execution. -/// \param cdata The supporting closure data. -pub type FTVMParallelLambda = ::std::option::Option< - unsafe extern "C" fn( - task_id: ::std::os::raw::c_int, - penv: *mut TVMParallelGroupEnv, - cdata: *mut ::std::os::raw::c_void, - ) -> ::std::os::raw::c_int, ->; -extern "C" { - /// \brief Backend function for running parallel jobs. - /// - /// \param flambda The parallel function to be launched. - /// \param cdata The closure data. - /// \param num_task Number of tasks to launch, can be 0, means launch - /// with all available threads. - /// - /// \return 0 when no error is thrown, -1 when failure happens - pub fn TVMBackendParallelLaunch( - flambda: FTVMParallelLambda, - cdata: *mut ::std::os::raw::c_void, - num_task: ::std::os::raw::c_int, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief BSP barrrier between parallel threads - /// \param task_id the task id of the function. - /// \param penv The parallel environment backs the execution. - /// \return 0 when no error is thrown, -1 when failure happens - pub fn TVMBackendParallelBarrier( - task_id: ::std::os::raw::c_int, - penv: *mut TVMParallelGroupEnv, - ) -> ::std::os::raw::c_int; -} -extern "C" { - /// \brief Simple static initialization function. - /// Run f once and set handle to be not null. - /// This function is mainly used for test purpose. - /// - /// \param handle An global address to indicate f - /// \param f The function to be ran - /// \param cdata The closure data to pass to the function. - /// \param nbytes Number of bytes in the closure data. - /// \return 0 when no error is thrown, -1 when failure happens - pub fn TVMBackendRunOnce( - handle: *mut *mut ::std::os::raw::c_void, - f: ::std::option::Option< - unsafe extern "C" fn(arg1: *mut ::std::os::raw::c_void) -> ::std::os::raw::c_int, - >, - cdata: *mut ::std::os::raw::c_void, - nbytes: ::std::os::raw::c_int, - ) -> ::std::os::raw::c_int; -} +/// \return 0 if success, -1 if failure happens + pub type TVMExtensionFuncDeclarer = :: std :: option :: Option < unsafe extern "C" fn ( register_func_handle : TVMFunctionHandle ) -> :: std :: os :: raw :: c_int > ; extern "C" { + /// \brief Wrap a TVMPackedCFunc to become a FunctionHandle. +/// +/// The resource_handle will be managed by TVM API, until the function is no longer used. +/// +/// \param func The packed C function. +/// \param resource_handle The resource handle from front-end, can be NULL. +/// \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL +/// \param out the result function handle. +/// \return 0 when success, -1 when failure happens + pub fn TVMFuncCreateFromCFunc ( func : TVMPackedCFunc , resource_handle : * mut :: std :: os :: raw :: c_void , fin : TVMPackedCFuncFinalizer , out : * mut TVMFunctionHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { + /// \brief Register the function to runtime's global table. +/// +/// The registered function then can be pulled by the backend by the name. +/// +/// \param name The name of the function. +/// \param f The function to be registered. +/// \param override Whether allow override already registered function. + pub fn TVMFuncRegisterGlobal ( name : * const :: std :: os :: raw :: c_char , f : TVMFunctionHandle , override_ : :: std :: os :: raw :: c_int ) -> :: std :: os :: raw :: c_int ; } extern "C" { + /// \brief Get a global function. +/// +/// \param name The name of the function. +/// \param out the result function pointer, NULL if it does not exist. +/// +/// \note The function handle of global function is managed by TVM runtime, +/// So TVMFuncFree is should not be called when it get deleted. + pub fn TVMFuncGetGlobal ( name : * const :: std :: os :: raw :: c_char , out : * mut TVMFunctionHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { + /// \brief List all the globally registered function name +/// \param out_size The number of functions +/// \param out_array The array of function names. +/// \return 0 when success, -1 when failure happens + pub fn TVMFuncListGlobalNames ( out_size : * mut :: std :: os :: raw :: c_int , out_array : * mut * mut * const :: std :: os :: raw :: c_char ) -> :: std :: os :: raw :: c_int ; } extern "C" { + /// \brief Allocate a nd-array's memory, +/// including space of shape, of given spec. +/// +/// \param shape The shape of the array, the data content will be copied to out +/// \param ndim The number of dimension of the array. +/// \param dtype_code The type code of the dtype +/// \param dtype_bits The number of bits of dtype +/// \param dtype_lanes The number of lanes in the dtype. +/// \param device_type The device type of context +/// \param device_id The device id of context. +/// \param out The output handle. +/// \return 0 when success, -1 when failure happens + pub fn TVMArrayAlloc ( shape : * const tvm_index_t , ndim : :: std :: os :: raw :: c_int , dtype_code : :: std :: os :: raw :: c_int , dtype_bits : :: std :: os :: raw :: c_int , dtype_lanes : :: std :: os :: raw :: c_int , device_type : :: std :: os :: raw :: c_int , device_id : :: std :: os :: raw :: c_int , out : * mut TVMArrayHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { + /// \brief Free the TVM Array. +/// \param handle The array handle to be freed. +/// \return 0 when success, -1 when failure happens + pub fn TVMArrayFree ( handle : TVMArrayHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { + /// \brief Copy array data from CPU byte array. +/// \param handle The array handle. +/// \param data the data pointer +/// \param nbytes The number of bytes to copy. +/// \return 0 when success, -1 when failure happens + pub fn TVMArrayCopyFromBytes ( handle : TVMArrayHandle , data : * mut :: std :: os :: raw :: c_void , nbytes : usize ) -> :: std :: os :: raw :: c_int ; } extern "C" { + /// \brief Copy array data to CPU byte array. +/// \param handle The array handle. +/// \param data the data pointer +/// \param nbytes The number of bytes to copy. +/// \return 0 when success, -1 when failure happens + pub fn TVMArrayCopyToBytes ( handle : TVMArrayHandle , data : * mut :: std :: os :: raw :: c_void , nbytes : usize ) -> :: std :: os :: raw :: c_int ; } extern "C" { + /// \brief Copy the array, both from and to must be valid during the copy. +/// \param from The array to be copied from. +/// \param to The target space. +/// \param stream The stream where the copy happens, can be NULL. +/// \return 0 when success, -1 when failure happens + pub fn TVMArrayCopyFromTo ( from : TVMArrayHandle , to : TVMArrayHandle , stream : TVMStreamHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { + /// \brief Produce an array from the DLManagedTensor that shares data memory +/// with the DLManagedTensor. +/// \param from The source DLManagedTensor. +/// \param out The output array handle. +/// \return 0 when success, -1 when failure happens + pub fn TVMArrayFromDLPack ( from : * mut DLManagedTensor , out : * mut TVMArrayHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { + /// \brief Produce a DLMangedTensor from the array that shares data memory with +/// the array. +/// \param from The source array. +/// \param out The DLManagedTensor handle. +/// \return 0 when success, -1 when failure happens + pub fn TVMArrayToDLPack ( from : TVMArrayHandle , out : * mut * mut DLManagedTensor ) -> :: std :: os :: raw :: c_int ; } extern "C" { + /// \brief Delete (free) a DLManagedTensor's data. +/// \param dltensor Pointer to the DLManagedTensor. + pub fn TVMDLManagedTensorCallDeleter ( dltensor : * mut DLManagedTensor ) ; } extern "C" { + /// \brief Create a new runtime stream. +/// +/// \param device_type The device type of context +/// \param device_id The device id of context +/// \param out The new stream handle +/// \return 0 when success, -1 when failure happens + pub fn TVMStreamCreate ( device_type : :: std :: os :: raw :: c_int , device_id : :: std :: os :: raw :: c_int , out : * mut TVMStreamHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { + /// \brief Free a created stream handle. +/// +/// \param device_type The device type of context +/// \param device_id The device id of context +/// \param stream The stream to be freed +/// \return 0 when success, -1 when failure happens + pub fn TVMStreamFree ( device_type : :: std :: os :: raw :: c_int , device_id : :: std :: os :: raw :: c_int , stream : TVMStreamHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { + /// \brief Set the runtime stream of current thread to be stream. +/// The subsequent calls to the same device_type +/// will use the setted stream handle. +/// The specific type of stream is runtime device dependent. +/// +/// \param device_type The device type of context +/// \param device_id The device id of context. +/// \param handle The stream handle. +/// \return 0 when success, -1 when failure happens + pub fn TVMSetStream ( device_type : :: std :: os :: raw :: c_int , device_id : :: std :: os :: raw :: c_int , handle : TVMStreamHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { + /// \brief Wait until all computations on stream completes. +/// +/// \param device_type The device type of context +/// \param device_id The device id of context. +/// \param stream The stream to be synchronized. +/// \return 0 when success, -1 when failure happens + pub fn TVMSynchronize ( device_type : :: std :: os :: raw :: c_int , device_id : :: std :: os :: raw :: c_int , stream : TVMStreamHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { + /// \brief Synchronize two streams of execution. +/// +/// \param device_type The device type of context +/// \param device_id The device id of context +/// \param src The source stream to synchronize. +/// \param dst The destination stream to synchronize. +/// \return 0 when success, -1 when failure happens + pub fn TVMStreamStreamSynchronize ( device_type : :: std :: os :: raw :: c_int , device_id : :: std :: os :: raw :: c_int , src : TVMStreamHandle , dst : TVMStreamHandle ) -> :: std :: os :: raw :: c_int ; } \ No newline at end of file diff --git a/rust/common/src/errors.rs b/rust/common/src/errors.rs index a81fab9f8c8f..4af4e3e38720 100644 --- a/rust/common/src/errors.rs +++ b/rust/common/src/errors.rs @@ -1,15 +1,42 @@ //! Error types for `TVMArgValue` and `TVMRetValue` conversions. +static TYPE_CODE_STRS: [&str; 15] = [ + "int", + "uint", + "float", + "handle", + "null", + "TVMType", + "TVMContext", + "ArrayHandle", + "NodeHandle", + "ModuleHandle", + "FuncHandle", + "str", + "bytes", + "NDArrayContainer", + "ExtBegin", +]; +fn type_code_to_string(type_code: &i64) -> String { + TYPE_CODE_STRS[*type_code as usize].to_string() +} + error_chain! { errors { - TryFromTVMArgValueError(expected: String, actual: String) { - description("mismatched types while converting from TVMArgValue") - display("expected `{}` but given `{}`", expected, actual) - } + // TryFromTVMArgValueError(expected: String, actual: String) { + // description("mismatched types while converting from TVMArgValue") + // display("expected `{}` but given `{}`", expected, actual) + // } + // + // TryFromTVMRetValueError(expected: String, actual: String) { + // description("mismatched types while downcasting TVMRetValue") + // display("invalid downcast: expected `{}` but given `{}`", expected, actual) + // } - TryFromTVMRetValueError(expected: String, actual: String) { - description("mismatched types while downcasting TVMRetValue") - display("invalid downcast: expected `{}` but given `{}`", expected, actual) + TryFromTVMRetValueError(expected_type: String, actual_type_code: i64) { + description("mismatched types while downcasting TVMRetValue") + display("invalid downcast: expected `{}` but was `{}`", + expected_type, type_code_to_string(actual_type_code)) } } } diff --git a/rust/common/src/lib.rs b/rust/common/src/lib.rs index ad4c4f23579e..bb1034c20d73 100644 --- a/rust/common/src/lib.rs +++ b/rust/common/src/lib.rs @@ -13,27 +13,18 @@ extern crate error_chain; pub mod ffi { #![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)] - #[cfg(feature = "frontend")] - pub extern crate tvm_sys as ts; + use std::os::raw::{c_char, c_int, c_void}; - #[cfg(feature = "runtime")] - pub mod runtime { - use std::os::raw::{c_char, c_int, c_void}; + include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs")); - include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs")); - - pub type BackendPackedCFunc = extern "C" fn( - args: *const TVMValue, - type_codes: *const c_int, - num_args: c_int, - ) -> c_int; - } + pub type BackendPackedCFunc = + extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int; } pub mod errors; -pub mod ty; +pub mod packed_func; pub mod value; pub use errors::*; -pub use ty::TVMTypeCode; -pub use value::{TVMArgValue, TVMRetValue, TVMValue}; +pub use ffi::TVMType; +pub use packed_func::{TVMArgValue, TVMRetValue}; diff --git a/rust/common/src/packed_func.rs b/rust/common/src/packed_func.rs new file mode 100644 index 000000000000..cef5a4dba826 --- /dev/null +++ b/rust/common/src/packed_func.rs @@ -0,0 +1,329 @@ +use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void}; + +use crate::{errors::*, ffi::*, value::*}; + +pub type PackedFunc = Box TVMRetValue + Send + Sync>; + +/// Calls a packed function and returns a `TVMRetValue`. +/// +/// # Example +/// +/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)` +#[macro_export] +macro_rules! call_packed { + ($fn:expr, $($args:expr),+) => { + $fn(&[$($args.into(),)+]) + }; + ($fn:expr) => { + $fn(&Vec::new()) + }; +} + +/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way +/// to obtain a `TVMArgValue` is automatically via `call_packed!`. +#[derive(Clone, Copy)] +pub struct TVMArgValue<'a> { + pub _lifetime: PhantomData<&'a ()>, + pub value: TVMValue, + pub type_code: i64, +} + +impl<'a> TVMArgValue<'a> { + pub fn new(value: TVMValue, type_code: i64) -> Self { + TVMArgValue { + _lifetime: PhantomData, + value: value, + type_code: type_code, + } + } +} + +#[macro_export] +macro_rules! ensure_type { + ($val:ident, $actual:expr) => { + ensure!( + $val.type_code == $actual as i64, + "Could not downcast value. Expected type code `{}`, got `{}`", + $actual, + $val.type_code + ); + }; +} + +/// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode. +macro_rules! impl_prim_tvm_arg { + ($type:ty, $field:ident, $code:expr, $as:ty) => { + impl<'a> From<$type> for TVMArgValue<'a> { + fn from(val: $type) -> Self { + TVMArgValue { + value: TVMValue { $field: val as $as }, + type_code: $code as i64, + _lifetime: PhantomData, + } + } + } + impl<'a> From<&'a $type> for TVMArgValue<'a> { + fn from(val: &'a $type) -> Self { + TVMArgValue { + value: TVMValue { + $field: val.to_owned() as $as, + }, + type_code: $code as i64, + _lifetime: PhantomData, + } + } + } + impl<'a> TryFrom> for $type { + type Error = Error; + fn try_from(val: TVMArgValue<'a>) -> Result { + ensure_type!(val, $code); + Ok(unsafe { val.value.$field as $type }) + } + } + }; + ($type:ty, $field:ident, $code:expr) => { + impl_prim_tvm_arg!($type, $field, $code, $type); + }; + ($type:ty,v_int64) => { + impl_prim_tvm_arg!($type, v_int64, DLDataTypeCode_kDLInt, i64); + }; + ($type:ty,v_float64) => { + impl_prim_tvm_arg!($type, v_float64, DLDataTypeCode_kDLFloat, f64); + }; +} + +impl_prim_tvm_arg!(f32, v_float64); +impl_prim_tvm_arg!(f64, v_float64); +impl_prim_tvm_arg!(i8, v_int64); +impl_prim_tvm_arg!(u8, v_int64); +impl_prim_tvm_arg!(i32, v_int64); +impl_prim_tvm_arg!(u32, v_int64); +impl_prim_tvm_arg!(i64, v_int64); +impl_prim_tvm_arg!(u64, v_int64); +impl_prim_tvm_arg!(isize, v_int64); +impl_prim_tvm_arg!(usize, v_int64); + +impl<'a> From<&std::ffi::CString> for TVMArgValue<'a> { + fn from(string: &std::ffi::CString) -> Self { + TVMArgValue { + value: TVMValue { + v_handle: string.as_ptr() as *const _ as *mut c_void, + }, + type_code: TVMTypeCode_kStr as i64, + _lifetime: PhantomData, + } + } +} + +/// Creates a conversion to a `TVMArgValue` for an object handle. +impl<'a, T> From<*const T> for TVMArgValue<'a> { + fn from(ptr: *const T) -> Self { + TVMArgValue { + value: TVMValue { + v_handle: ptr as *mut T as *mut c_void, + }, + type_code: TVMTypeCode_kArrayHandle as i64, + _lifetime: PhantomData, + } + } +} + +/// Creates a conversion to a `TVMArgValue` for a mutable object handle. +impl<'a, T> From<*mut T> for TVMArgValue<'a> { + fn from(ptr: *mut T) -> Self { + TVMArgValue { + value: TVMValue { + v_handle: ptr as *mut c_void, + }, + type_code: TVMTypeCode_kHandle as i64, + _lifetime: PhantomData, + } + } +} + +impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> { + fn from(arr: &'a mut DLTensor) -> Self { + TVMArgValue { + value: TVMValue { + v_handle: arr as *mut _ as *mut c_void, + }, + type_code: TVMTypeCode_kArrayHandle as i64, + _lifetime: PhantomData, + } + } +} + +impl<'a> From<&'a DLTensor> for TVMArgValue<'a> { + fn from(arr: &'a DLTensor) -> Self { + TVMArgValue { + value: TVMValue { + v_handle: arr as *const _ as *mut DLTensor as *mut c_void, + }, + type_code: TVMTypeCode_kArrayHandle as i64, + _lifetime: PhantomData, + } + } +} + +impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for TVMType { + type Error = Error; + fn try_from(arg: &'a TVMArgValue<'v>) -> Result { + ensure_type!(arg, TVMTypeCode_kTVMType); + Ok(unsafe { arg.value.v_type.into() }) + } +} + +/// An owned TVMPODValue. Can be converted from a variety of primitive and object types. +/// Can be downcasted using `try_from` if it contains the desired type. +/// +/// # Example +/// +/// ``` +/// let a = 42u32; +/// let b: i64 = TVMRetValue::from(a).try_into().unwrap(); +/// +/// let s = "hello, world!"; +/// let t: TVMRetValue = s.into(); +/// assert_eq!(String::try_from(t).unwrap(), s); +/// ``` +pub struct TVMRetValue { + /// A primitive return value, if any. + pub prim_value: u64, + /// An object return value, if any. + pub box_value: Box, + /// The DLDataTypeCode which determines whether `prim_value` or `box_value` is in use. + pub type_code: i64, +} + +impl TVMRetValue { + pub fn from_tvm_value(value: TVMValue, type_code: i64) -> Self { + unsafe { + Self { + prim_value: match type_code { + 0 | 1 => value.v_int64 as u64, + 2 => value.v_float64 as u64, + 3 | 7 | 8 | 9 | 10 => value.v_handle as u64, + 11 | 12 => value.v_str as u64, + _ => 0, + } as u64, + box_value: box (), + type_code: type_code, + } + } + } + + pub fn into_tvm_value(self) -> (TVMValue, TVMTypeCode) { + let val = match self.type_code { + 0 | 1 => TVMValue { + v_int64: self.prim_value.clone() as i64, + }, + 2 => TVMValue { + v_float64: self.prim_value.clone() as f64, + }, + 3 | 7 | 8 | 9 | 10 | 13 => TVMValue { + v_handle: Box::into_raw(self.box_value) as *mut c_void, + }, + 11 | 12 => TVMValue { + v_str: Box::into_raw(self.box_value) as *const _, + }, + _ => unreachable!(), + }; + (val, self.type_code as TVMTypeCode) + } +} + +impl Default for TVMRetValue { + fn default() -> Self { + TVMRetValue { + prim_value: 0, + box_value: box (), + type_code: 0, + } + } +} + +macro_rules! impl_prim_ret_value { + ($type:ty, $code:expr) => { + impl From<$type> for TVMRetValue { + fn from(val: $type) -> Self { + TVMRetValue { + prim_value: val as u64, + box_value: box (), + type_code: $code, + } + } + } + impl TryFrom for $type { + type Error = Error; + fn try_from(ret: TVMRetValue) -> Result<$type> { + if ret.type_code == $code { + Ok(ret.prim_value as $type) + } else { + bail!(ErrorKind::TryFromTVMRetValueError( + stringify!($type).to_string(), + ret.type_code + )) + } + } + } + }; +} + +macro_rules! impl_boxed_ret_value { + ($type:ty, $code:expr) => { + impl From<$type> for TVMRetValue { + fn from(val: $type) -> Self { + TVMRetValue { + prim_value: 0, + box_value: box val, + type_code: $code, + } + } + } + impl TryFrom for $type { + type Error = Error; + fn try_from(ret: TVMRetValue) -> Result<$type> { + if let Ok(val) = ret.box_value.downcast::<$type>() { + Ok(*val) + } else { + bail!(ErrorKind::TryFromTVMRetValueError( + stringify!($type).to_string(), + ret.type_code + )) + } + } + } + }; +} + +impl_prim_ret_value!(i8, 0); +impl_prim_ret_value!(u8, 1); +impl_prim_ret_value!(i16, 0); +impl_prim_ret_value!(u16, 1); +impl_prim_ret_value!(i32, 0); +impl_prim_ret_value!(u32, 1); +impl_prim_ret_value!(f32, 2); +impl_prim_ret_value!(i64, 0); +impl_prim_ret_value!(u64, 1); +impl_prim_ret_value!(f64, 2); +impl_prim_ret_value!(isize, 0); +impl_prim_ret_value!(usize, 1); +impl_boxed_ret_value!(String, 11); + +// @see `WrapPackedFunc` in `llvm_module.cc`. +pub(super) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc { + box move |args: &[TVMArgValue]| { + func( + args.iter() + .map(|ref arg| arg.value) + .collect::>() + .as_ptr(), + args.iter() + .map(|ref arg| arg.type_code as i32) + .collect::>() + .as_ptr() as *const i32, + args.len() as i32, + ); + TVMRetValue::default() + } +} diff --git a/rust/common/src/ty.rs b/rust/common/src/ty.rs index 126bcd44527e..e69de29bb2d1 100644 --- a/rust/common/src/ty.rs +++ b/rust/common/src/ty.rs @@ -1,144 +0,0 @@ -//! This module containes `TVMTypeCode` and `TVMType` with some conversion methods. -//! -//! # Example -//! -//! ``` -//! let dtype = TVMType::from("float"); -//! println!("dtype is: {}", dtype); -//! ``` - -use std::{ - ffi::{CStr, CString}, - fmt::{self, Display, Formatter}, -}; - -/// TVM type codes. -#[repr(u32)] -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -pub enum TVMTypeCode { - kDLInt = 0, - kDLUInt = 1, - kDLFloat = 2, - kHandle = 3, - kNull = 4, - kTVMType = 5, - kTVMContext = 6, - kArrayHandle = 7, - kNodeHandle = 8, - kModuleHandle = 9, - kFuncHandle = 10, - kStr = 11, - kBytes = 12, - kNDArrayContainer = 13, -} - -impl Default for TVMTypeCode { - fn default() -> Self { - TVMTypeCode::kDLInt - } -} - -impl From for i64 { - fn from(arg: TVMTypeCode) -> i64 { - match arg { - TVMTypeCode::kDLInt => 0, - TVMTypeCode::kDLUInt => 1, - TVMTypeCode::kDLFloat => 2, - TVMTypeCode::kHandle => 3, - TVMTypeCode::kNull => 4, - TVMTypeCode::kTVMType => 5, - TVMTypeCode::kTVMContext => 6, - TVMTypeCode::kArrayHandle => 7, - TVMTypeCode::kNodeHandle => 8, - TVMTypeCode::kModuleHandle => 9, - TVMTypeCode::kFuncHandle => 10, - TVMTypeCode::kStr => 11, - TVMTypeCode::kBytes => 12, - TVMTypeCode::kNDArrayContainer => 13, - } - } -} - -impl Into for i64 { - fn into(self) -> TVMTypeCode { - match self { - 0 => TVMTypeCode::kDLInt, - 1 => TVMTypeCode::kDLUInt, - 2 => TVMTypeCode::kDLFloat, - 3 => TVMTypeCode::kHandle, - 4 => TVMTypeCode::kNull, - 5 => TVMTypeCode::kTVMType, - 6 => TVMTypeCode::kTVMContext, - 7 => TVMTypeCode::kArrayHandle, - 8 => TVMTypeCode::kNodeHandle, - 9 => TVMTypeCode::kModuleHandle, - 10 => TVMTypeCode::kFuncHandle, - 11 => TVMTypeCode::kStr, - 12 => TVMTypeCode::kBytes, - 13 => TVMTypeCode::kNDArrayContainer, - _ => unreachable!(), - } - } -} - -impl Display for TVMTypeCode { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!( - f, - "{}", - match self { - TVMTypeCode::kDLInt => "int", - TVMTypeCode::kDLUInt => "uint", - TVMTypeCode::kDLFloat => "float", - TVMTypeCode::kHandle => "handle", - TVMTypeCode::kNull => "null", - TVMTypeCode::kTVMType => "TVM type", - TVMTypeCode::kTVMContext => "TVM context", - TVMTypeCode::kArrayHandle => "Array handle", - TVMTypeCode::kNodeHandle => "Node handle", - TVMTypeCode::kModuleHandle => "Module handle", - TVMTypeCode::kFuncHandle => "Function handle", - TVMTypeCode::kStr => "string", - TVMTypeCode::kBytes => "bytes", - TVMTypeCode::kNDArrayContainer => "ndarray container", - } - ) - } -} - -macro_rules! impl_prim_type { - ($type:ty, $variant:ident) => { - impl<'a> From<&'a $type> for TVMTypeCode { - fn from(_arg: &$type) -> Self { - TVMTypeCode::$variant - } - } - - impl<'a> From<&'a mut $type> for TVMTypeCode { - fn from(_arg: &mut $type) -> Self { - TVMTypeCode::$variant - } - } - }; -} - -impl_prim_type!(usize, kDLInt); -impl_prim_type!(i64, kDLInt); -impl_prim_type!(i32, kDLInt); -impl_prim_type!(i16, kDLInt); -impl_prim_type!(i8, kDLInt); - -impl_prim_type!(u64, kDLUInt); -impl_prim_type!(u32, kDLUInt); -impl_prim_type!(u16, kDLUInt); -impl_prim_type!(u8, kDLUInt); - -impl_prim_type!(f64, kDLFloat); -impl_prim_type!(f32, kDLFloat); - -impl_prim_type!(str, kStr); -impl_prim_type!(CStr, kStr); -impl_prim_type!(String, kStr); -impl_prim_type!(CString, kStr); - -impl_prim_type!([u8], kBytes); diff --git a/rust/common/src/value.rs b/rust/common/src/value.rs index 6da8b27e8660..7fc82c6af665 100644 --- a/rust/common/src/value.rs +++ b/rust/common/src/value.rs @@ -1,559 +1,92 @@ -//! This module provides the the wrapped `TVMValue`, `TVMArgValue` and `TVMRetValue` -//! required for using TVM functions. +use crate::ffi::*; -use std::{ - any::Any, - convert::TryFrom, - ffi::{CStr, CString}, - fmt::{self, Debug, Formatter}, - marker::PhantomData, - mem, - ops::Deref, - os::raw::{c_char, c_void}, -}; - -#[cfg(feature = "runtime")] -use ffi::runtime::TVMValue as _TVMValue; - -#[cfg(feature = "frontend")] -use ffi::ts::TVMValue as _TVMValue; - -use errors::*; - -use ty::TVMTypeCode; - -/// Wrapped TVMValue type. -#[derive(Clone, Copy)] -pub struct TVMValue { - pub inner: _TVMValue, -} - -impl TVMValue { - /// Creates TVMValue from the raw part. - pub fn new(inner: _TVMValue) -> Self { - TVMValue { inner } - } - - pub(crate) fn into_raw(self) -> _TVMValue { - self.inner - } -} - -impl Debug for TVMValue { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - unsafe { - write!( - f, - "TVMValue: [v_int64: {:?}], [v_float64: {:?}], [v_handle: {:?}],\ - [v_str: {:?}]", - self.inner.v_int64, self.inner.v_float64, self.inner.v_handle, self.inner.v_str - ) +impl TVMType { + fn new(type_code: u8, bits: u8, lanes: u16) -> Self { + Self { + code: type_code, + bits, + lanes, } } } -impl Deref for TVMValue { - type Target = _TVMValue; - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -macro_rules! impl_prim_val { - ($type:ty, $field:ident, $cast:ty) => { - impl From<$type> for TVMValue { - fn from(arg: $type) -> Self { - let inner = _TVMValue { - $field: arg as $cast, - }; - Self::new(inner) - } - } - - impl<'a> From<&'a $type> for TVMValue { - fn from(arg: &$type) -> Self { - let inner = _TVMValue { - $field: *arg as $cast, - }; - Self::new(inner) - } - } - - impl<'a> From<&'a mut $type> for TVMValue { - fn from(arg: &mut $type) -> Self { - let inner = _TVMValue { - $field: *arg as $cast, - }; - Self::new(inner) - } - } - - impl TryFrom for $type { - type Error = Error; - fn try_from(val: TVMValue) -> Result { - Ok(unsafe { val.inner.$field as $type }) - } - } - - impl<'a> TryFrom<&'a TVMValue> for $type { - type Error = Error; - fn try_from(val: &TVMValue) -> Result { - Ok(unsafe { val.into_raw().$field as $type }) - } +/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}` +/// such as "int32", "float32" or with lane "float32x1". +impl<'a> From<&'a str> for TVMType { + fn from(type_str: &'a str) -> Self { + if type_str == "bool" { + return TVMType::new(1, 1, 1); } - impl<'a> TryFrom<&'a mut TVMValue> for $type { - type Error = Error; - fn try_from(val: &mut TVMValue) -> Result { - Ok(unsafe { val.into_raw().$field as $type }) + let mut type_lanes = type_str.split("x"); + let typ = type_lanes.next().expect("Missing dtype"); + let lanes = type_lanes + .next() + .map(|l| u16::from_str_radix(l, 10).expect(&format!("Bad dtype lanes: {}", l))) + .unwrap_or(1); + let (type_name, bits) = match typ.find(char::is_numeric) { + Some(idx) => { + let (name, bits_str) = typ.split_at(idx); + ( + name, + u8::from_str_radix(bits_str, 10) + .expect(&format!("Bad dtype bits: {}", bits_str)), + ) } - } - }; -} - -impl_prim_val!(isize, v_int64, i64); -impl_prim_val!(i64, v_int64, i64); -impl_prim_val!(i32, v_int64, i64); -impl_prim_val!(i16, v_int64, i64); -impl_prim_val!(i8, v_int64, i64); -impl_prim_val!(usize, v_int64, i64); -impl_prim_val!(u64, v_int64, i64); -impl_prim_val!(u32, v_int64, i64); -impl_prim_val!(u16, v_int64, i64); -impl_prim_val!(u8, v_int64, i64); - -impl_prim_val!(f64, v_float64, f64); -impl_prim_val!(f32, v_float64, f64); - -impl<'a> From<&'a str> for TVMValue { - fn from(arg: &str) -> TVMValue { - let arg = CString::new(arg).unwrap(); - let inner = _TVMValue { - v_str: arg.as_ptr() as *const c_char, - }; - mem::forget(arg); - Self::new(inner) - } -} - -impl<'a> From<&'a String> for TVMValue { - fn from(arg: &String) -> TVMValue { - let arg = CString::new(arg.as_bytes()).unwrap(); - let inner = _TVMValue { - v_str: arg.as_ptr() as *const c_char, - }; - mem::forget(arg); - Self::new(inner) - } -} - -impl<'a> From<&'a CString> for TVMValue { - fn from(arg: &CString) -> TVMValue { - let arg = arg.to_owned(); - let inner = _TVMValue { - v_str: arg.as_ptr() as *const c_char, + None => (typ, 32), }; - mem::forget(arg); - Self::new(inner) - } -} -impl<'a> From<&'a [u8]> for TVMValue { - fn from(arg: &[u8]) -> TVMValue { - let arg = arg.to_owned(); - let inner = _TVMValue { - v_handle: &arg as *const _ as *mut c_void, + let type_code = match type_name { + "int" => 0, + "uint" => 1, + "float" => 2, + "handle" => 3, + _ => unimplemented!(), }; - mem::forget(arg); - Self::new(inner) - } -} - -/// Captures both `TVMValue` and `TVMTypeCode` needed for TVM function. -/// The preferred way to obtain a `TVMArgValue` is automatically via `call_packed!`. -/// or in the frontend crate, with `function::Builder`. Checkout the methods for conversions. -/// -/// ## Example -/// -/// ``` -/// let s = "hello".to_string(); -/// let arg = TVMArgValue::from(&s); -/// let tvm: String = arg.try_into().unwrap(); -/// assert_eq!(arg, s); -/// ``` -#[derive(Debug, Clone, Copy)] -pub struct TVMArgValue<'a> { - /// The wrapped TVMValue - pub value: TVMValue, - /// The matching type code. - pub type_code: TVMTypeCode, - /// This is only exposed to runtime and frontend crates and is not meant to be used directly. - pub lifetime: PhantomData<&'a ()>, -} -impl<'a> TVMArgValue<'a> { - pub fn new(value: TVMValue, type_code: TVMTypeCode) -> Self { - TVMArgValue { - value: value, - type_code: type_code, - lifetime: PhantomData, - } + TVMType::new(type_code, bits, lanes) } } -impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for i64 { - type Error = Error; - fn try_from(arg: &TVMArgValue<'a>) -> Result { - if (arg.type_code == TVMTypeCode::kDLInt) - | (arg.type_code == TVMTypeCode::kDLUInt) - | (arg.type_code == TVMTypeCode::kNull) - { - Ok(unsafe { arg.value.inner.v_int64 }) - } else { - bail!(ErrorKind::TryFromTVMArgValueError( - stringify!(i64).to_string(), - arg.type_code.to_string() - )) +impl std::fmt::Display for TVMType { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + if self.bits == 1 && self.lanes == 1 { + return write!(f, "bool"); } - } -} - -impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for f64 { - type Error = Error; - fn try_from(arg: &TVMArgValue<'a>) -> Result { - if arg.type_code == TVMTypeCode::kDLFloat { - Ok(unsafe { arg.value.inner.v_float64 }) - } else { - bail!(ErrorKind::TryFromTVMArgValueError( - stringify!(f64).to_string(), - arg.type_code.to_string() - )) + let mut type_str = match self.code { + 0 => "int", + 1 => "uint", + 2 => "float", + 4 => "handle", + _ => "unknown", } - } -} + .to_string(); -impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for String { - type Error = Error; - fn try_from(arg: &TVMArgValue<'a>) -> Result { - if arg.type_code == TVMTypeCode::kStr { - let ret_str = unsafe { - match CStr::from_ptr(arg.value.inner.v_str).to_str() { - Ok(s) => s, - Err(_) => "Invalid UTF-8 message", - } - }; - Ok(ret_str.to_string()) - } else { - bail!(ErrorKind::TryFromTVMArgValueError( - stringify!(String).to_string(), - arg.type_code.to_string() - )) + type_str += &self.bits.to_string(); + if self.lanes > 1 { + type_str += &format!("x{}", self.lanes); } + f.write_str(&type_str) } } -/// Main way to create a TVMArgValue from suported Rust values. -impl<'b, 'a: 'b, T: 'b + ?Sized> From<&'b T> for TVMArgValue<'a> -where - TVMValue: From<&'b T>, - TVMTypeCode: From<&'b T>, -{ - fn from(arg: &'b T) -> Self { - TVMArgValue::new(TVMValue::from(arg), TVMTypeCode::from(arg)) - } -} - -/// Creates a conversion to a `TVMArgValue` for an object handle. -impl<'a, T> From<*const T> for TVMArgValue<'a> { - fn from(ptr: *const T) -> Self { - let value = TVMValue::new(_TVMValue { - v_handle: ptr as *mut T as *mut c_void, - }); - - TVMArgValue::new(value, TVMTypeCode::kArrayHandle) - } -} - -/// Creates a conversion to a `TVMArgValue` for a mutable object handle. -impl<'a, T> From<*mut T> for TVMArgValue<'a> { - fn from(ptr: *mut T) -> Self { - let value = TVMValue::new(_TVMValue { - v_handle: ptr as *mut c_void, - }); - - TVMArgValue::new(value, TVMTypeCode::kHandle) - } -} - -/// An owned version of TVMPODValue. It can be converted from varieties of -/// primitive and object types. -/// It can be downcasted using `try_from` if it contains the desired type. -/// -/// # Example -/// -/// ``` -/// let a = 42u32; -/// let b: i64 = TVMRetValue::from(a).try_into().unwrap(); -/// -/// let s = "hello, world!"; -/// let t: TVMRetValue = s.into(); -/// assert_eq!(String::try_from(t).unwrap(), s); -/// ``` -pub struct TVMRetValue { - /// A primitive return value, if any. - pub prim_value: usize, - /// An object return value, if any. - pub box_value: Box, - pub type_code: TVMTypeCode, -} - -impl TVMRetValue { - fn new(prim_value: usize, box_value: Box, type_code: TVMTypeCode) -> Self { - Self { - prim_value, - box_value, - type_code, - } - } - - /// unsafe function to create `TVMRetValue` from `TVMValue` and - /// its matching `TVMTypeCode`. - pub unsafe fn from_tvm_value(value: TVMValue, type_code: TVMTypeCode) -> Self { - let value = value.into_raw(); - match type_code { - TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt => { - Self::new(value.v_int64 as usize, box (), type_code) - } - TVMTypeCode::kDLFloat => Self::new(value.v_float64 as usize, box (), type_code), - TVMTypeCode::kHandle - | TVMTypeCode::kArrayHandle - | TVMTypeCode::kNodeHandle - | TVMTypeCode::kModuleHandle - | TVMTypeCode::kFuncHandle => { - Self::new(value.v_handle as usize, box value.v_handle, type_code) - } - TVMTypeCode::kStr | TVMTypeCode::kBytes => { - Self::new(value.v_str as usize, box (value.v_str), type_code) - } - _ => Self::new(0usize, box (), type_code), - } - } - - /// Returns the underlying `TVMValue` and `TVMTypeCode`. - pub fn into_tvm_value(self) -> (TVMValue, TVMTypeCode) { - let val = match self.type_code { - TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt => TVMValue::new(_TVMValue { - v_int64: self.prim_value as i64, - }), - TVMTypeCode::kDLFloat => TVMValue::new(_TVMValue { - v_float64: self.prim_value as f64, - }), - TVMTypeCode::kHandle - | TVMTypeCode::kArrayHandle - | TVMTypeCode::kNodeHandle - | TVMTypeCode::kModuleHandle - | TVMTypeCode::kFuncHandle - | TVMTypeCode::kNDArrayContainer => TVMValue::new(_TVMValue { - v_handle: self.prim_value as *const c_void as *mut c_void, - }), - TVMTypeCode::kStr | TVMTypeCode::kBytes => TVMValue::new(_TVMValue { - v_str: self.prim_value as *const c_char, - }), - _ => unreachable!(), - }; - (val, self.type_code) - } -} - -impl Default for TVMRetValue { - fn default() -> Self { - TVMRetValue { - prim_value: 0usize, - box_value: box (), - type_code: TVMTypeCode::default(), - } - } -} - -impl Clone for TVMRetValue { - fn clone(&self) -> Self { - match self.type_code { - TVMTypeCode::kDLInt | TVMTypeCode::kDLUInt | TVMTypeCode::kDLFloat => { - Self::new(self.prim_value.clone(), box (), self.type_code.clone()) - } - TVMTypeCode::kHandle - | TVMTypeCode::kArrayHandle - | TVMTypeCode::kNodeHandle - | TVMTypeCode::kModuleHandle - | TVMTypeCode::kFuncHandle - | TVMTypeCode::kNDArrayContainer => Self::new( - self.prim_value.clone(), - box (self.prim_value.clone() as *const c_void as *mut c_void), - self.type_code.clone(), - ), - TVMTypeCode::kStr | TVMTypeCode::kBytes => Self::new( - self.prim_value.clone(), - box (self.prim_value.clone() as *const c_char), - self.type_code.clone(), - ), - _ => unreachable!(), - } - } -} - -impl Debug for TVMRetValue { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!( - f, - "prim_value: {:?}, box_value: {:?}, type_code: {:?}", - self.prim_value, self.prim_value as *const c_void as *mut c_void, self.type_code - ) - } -} - -macro_rules! impl_prim_ret_value { - ($type:ty, $code:expr) => { - impl From<$type> for TVMRetValue { - fn from(val: $type) -> Self { - TVMRetValue { - prim_value: val as usize, - box_value: box (), - type_code: $code, - } - } - } - - impl<'a> From<&'a $type> for TVMRetValue { - fn from(val: &$type) -> Self { - TVMRetValue { - prim_value: *val as usize, - box_value: box (), - type_code: $code, - } - } - } - - impl<'a> From<&'a mut $type> for TVMRetValue { - fn from(val: &mut $type) -> Self { - TVMRetValue { - prim_value: *val as usize, - box_value: box (), - type_code: $code, - } - } - } - - impl TryFrom for $type { - type Error = Error; - fn try_from(ret: TVMRetValue) -> Result<$type> { - if ret.type_code == $code { - Ok(ret.prim_value as $type) - } else { - bail!(ErrorKind::TryFromTVMRetValueError( - stringify!($type).to_string(), - ret.type_code.to_string(), - )) - } +macro_rules! impl_tvm_val_from_pod { + ($field:ident, $ty:ty) => { + impl From<$ty> for TVMValue { + fn from(val: $ty) -> Self { + TVMValue { $field: val } } } }; } -impl_prim_ret_value!(i8, TVMTypeCode::kDLInt); -impl_prim_ret_value!(i16, TVMTypeCode::kDLInt); -impl_prim_ret_value!(i32, TVMTypeCode::kDLInt); -impl_prim_ret_value!(i64, TVMTypeCode::kDLInt); -impl_prim_ret_value!(isize, TVMTypeCode::kDLInt); - -impl_prim_ret_value!(u8, TVMTypeCode::kDLUInt); -impl_prim_ret_value!(u16, TVMTypeCode::kDLUInt); -impl_prim_ret_value!(u32, TVMTypeCode::kDLUInt); -impl_prim_ret_value!(u64, TVMTypeCode::kDLUInt); -impl_prim_ret_value!(usize, TVMTypeCode::kDLUInt); +impl_tvm_val_from_pod!(v_type, TVMType); +impl_tvm_val_from_pod!(v_ctx, TVMContext); -impl_prim_ret_value!(f32, TVMTypeCode::kDLFloat); -impl_prim_ret_value!(f64, TVMTypeCode::kDLFloat); - -macro_rules! impl_ptr_ret_value { - ($type:ty) => { - impl From<$type> for TVMRetValue { - fn from(ptr: $type) -> Self { - TVMRetValue { - prim_value: ptr as usize, - box_value: box (), - type_code: TVMTypeCode::kHandle, - } - } - } - - impl TryFrom for $type { - type Error = Error; - fn try_from(ret: TVMRetValue) -> Result<$type> { - if ret.type_code == TVMTypeCode::kHandle { - Ok(ret.prim_value as $type) - } else { - bail!(ErrorKind::TryFromTVMRetValueError( - stringify!($type).to_string(), - ret.type_code.to_string(), - )) - } - } +impl From for TVMValue { + fn from(dev: DLDeviceType) -> Self { + TVMValue { + v_int64: dev as i64, } - }; -} - -impl_ptr_ret_value!(*const c_void); -impl_ptr_ret_value!(*mut c_void); - -impl From for TVMRetValue { - fn from(val: String) -> Self { - let pval = val.as_ptr() as *const c_char as usize; - let bval = box (val.as_ptr() as *const c_char); - mem::forget(val); - TVMRetValue::new(pval, bval, TVMTypeCode::kStr) - } -} - -impl TryFrom for String { - type Error = Error; - fn try_from(ret: TVMRetValue) -> Result { - // Note: simple downcast doesn't work for function call return values - let ret_str = unsafe { - match CStr::from_ptr(ret.prim_value as *const c_char).to_str() { - Ok(s) => s, - Err(_) => "Invalid UTF-8 message", - } - }; - - Ok(ret_str.to_string()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::convert::TryInto; - - #[test] - fn numeric() { - macro_rules! arg_ret_tests { - ($v:expr; $($ty:ty),+) => {{ - $( - let v = $v as $ty; - let b = TVMRetValue::from(&v); - let b: $ty = b.try_into().unwrap(); - assert_eq!(b, v); - )+ - }}; - } - - arg_ret_tests!(42; i8, i16, i32, i64, f32, f64); - } - - #[test] - fn string() { - let s = "hello".to_string(); - let tvm_arg: String = TVMRetValue::from(s.clone()).try_into().unwrap(); - assert_eq!(tvm_arg, s); } } diff --git a/rust/common/tvm-sys/Cargo.toml b/rust/common/tvm-sys/Cargo.toml deleted file mode 100644 index 117d174b4cbd..000000000000 --- a/rust/common/tvm-sys/Cargo.toml +++ /dev/null @@ -1,9 +0,0 @@ -[package] -name = "tvm-sys" -version = "0.1.0" -authors = ["TVM Contributors"] -license = "Apache-2.0" -description = "Raw C API" - -[build-dependencies] -bindgen = "0.37.4" diff --git a/rust/common/tvm-sys/src/lib.rs b/rust/common/tvm-sys/src/lib.rs deleted file mode 100644 index 15f1ea3a611c..000000000000 --- a/rust/common/tvm-sys/src/lib.rs +++ /dev/null @@ -1,9 +0,0 @@ -#![allow( - non_camel_case_types, - non_snake_case, - non_upper_case_globals, - dead_code, - improper_ctypes -)] - -include!("bindgen.rs"); diff --git a/rust/frontend/Cargo.toml b/rust/frontend/Cargo.toml index db261551e36f..cf09133e2a15 100644 --- a/rust/frontend/Cargo.toml +++ b/rust/frontend/Cargo.toml @@ -19,7 +19,7 @@ error-chain = "0.12.0" lazy_static = "1.1.0" ndarray = "0.12.1" num-traits = "0.2" -tvm-common = { version = "0.1.0", path = "../common/", features = ["frontend"] } +tvm-common = { version = "0.1.0", path = "../common/", features = ["bindings"] } [features] blas = ["ndarray/blas"] diff --git a/rust/frontend/src/bytearray.rs b/rust/frontend/src/bytearray.rs index 395f34c2428d..cfcfa06da25c 100644 --- a/rust/frontend/src/bytearray.rs +++ b/rust/frontend/src/bytearray.rs @@ -5,7 +5,7 @@ use std::os::raw::c_char; -use crate::ts; +use tvm_common::ffi; /// A struct holding TVM byte-array. /// @@ -19,11 +19,11 @@ use crate::ts; /// ``` #[derive(Debug, Clone)] pub struct TVMByteArray { - pub(crate) inner: ts::TVMByteArray, + pub(crate) inner: ffi::TVMByteArray, } impl TVMByteArray { - pub(crate) fn new(barr: ts::TVMByteArray) -> TVMByteArray { + pub(crate) fn new(barr: ffi::TVMByteArray) -> TVMByteArray { TVMByteArray { inner: barr } } @@ -46,7 +46,7 @@ impl TVMByteArray { impl<'a> From<&'a Vec> for TVMByteArray { fn from(arg: &Vec) -> Self { - let barr = ts::TVMByteArray { + let barr = ffi::TVMByteArray { data: arg.as_ptr() as *const c_char, size: arg.len(), }; diff --git a/rust/frontend/src/context.rs b/rust/frontend/src/context.rs index 65e11d82e2d0..3683e2ec3272 100644 --- a/rust/frontend/src/context.rs +++ b/rust/frontend/src/context.rs @@ -18,12 +18,15 @@ //! ``` use std::{ + convert::TryInto, fmt::{self, Display, Formatter}, os::raw::c_void, ptr, }; -use crate::{function, ts, Result}; +use tvm_common::ffi; + +use crate::{function, Result}; /// Device type can be from a supported device name. See the supported devices /// in [TVM](https://github.com/dmlc/tvm). @@ -45,35 +48,35 @@ impl Default for TVMDeviceType { } } -impl From for ts::DLDeviceType { +impl From for ffi::DLDeviceType { fn from(device_type: TVMDeviceType) -> Self { match device_type.0 { - 1 => ts::DLDeviceType_kDLCPU, - 2 => ts::DLDeviceType_kDLGPU, - 3 => ts::DLDeviceType_kDLCPUPinned, - 4 => ts::DLDeviceType_kDLOpenCL, - 7 => ts::DLDeviceType_kDLVulkan, - 8 => ts::DLDeviceType_kDLMetal, - 9 => ts::DLDeviceType_kDLVPI, - 10 => ts::DLDeviceType_kDLROCM, - 12 => ts::DLDeviceType_kDLExtDev, + 1 => ffi::DLDeviceType_kDLCPU, + 2 => ffi::DLDeviceType_kDLGPU, + 3 => ffi::DLDeviceType_kDLCPUPinned, + 4 => ffi::DLDeviceType_kDLOpenCL, + 7 => ffi::DLDeviceType_kDLVulkan, + 8 => ffi::DLDeviceType_kDLMetal, + 9 => ffi::DLDeviceType_kDLVPI, + 10 => ffi::DLDeviceType_kDLROCM, + 12 => ffi::DLDeviceType_kDLExtDev, _ => panic!("device type not found!"), } } } -impl From for TVMDeviceType { - fn from(device_type: ts::DLDeviceType) -> Self { +impl From for TVMDeviceType { + fn from(device_type: ffi::DLDeviceType) -> Self { match device_type { - ts::DLDeviceType_kDLCPU => TVMDeviceType(1), - ts::DLDeviceType_kDLGPU => TVMDeviceType(2), - ts::DLDeviceType_kDLCPUPinned => TVMDeviceType(3), - ts::DLDeviceType_kDLOpenCL => TVMDeviceType(4), - ts::DLDeviceType_kDLVulkan => TVMDeviceType(7), - ts::DLDeviceType_kDLMetal => TVMDeviceType(8), - ts::DLDeviceType_kDLVPI => TVMDeviceType(9), - ts::DLDeviceType_kDLROCM => TVMDeviceType(10), - ts::DLDeviceType_kDLExtDev => TVMDeviceType(12), + ffi::DLDeviceType_kDLCPU => TVMDeviceType(1), + ffi::DLDeviceType_kDLGPU => TVMDeviceType(2), + ffi::DLDeviceType_kDLCPUPinned => TVMDeviceType(3), + ffi::DLDeviceType_kDLOpenCL => TVMDeviceType(4), + ffi::DLDeviceType_kDLVulkan => TVMDeviceType(7), + ffi::DLDeviceType_kDLMetal => TVMDeviceType(8), + ffi::DLDeviceType_kDLVPI => TVMDeviceType(9), + ffi::DLDeviceType_kDLROCM => TVMDeviceType(10), + ffi::DLDeviceType_kDLExtDev => TVMDeviceType(12), _ => panic!("device type not found!"), } } @@ -185,20 +188,20 @@ impl<'a> From<&'a str> for TVMContext { impl TVMContext { /// Checks whether the context exists or not. pub fn exist(&self) -> bool { - let func = function::Function::get("_GetDeviceAttr", true /* is_global */) - .expect("API function always exists"); + let func = function::Function::get("_GetDeviceAttr").expect("API function always exists"); let dt = self.device_type.0 as usize; // `unwrap` is ok here because if there is any error, // if would occure inside `call_packed!` - let ret = call_packed!(func, &dt, &self.device_id, &0) + let ret: u64 = call_packed!(func, &dt, &self.device_id, &0) .unwrap() - .prim_value; + .try_into() + .unwrap(); ret != 0 } /// Synchronize the context stream. pub fn sync(&self) -> Result<()> { - check_call!(ts::TVMSynchronize( + check_call!(ffi::TVMSynchronize( self.device_type.0 as i32, self.device_id as i32, ptr::null_mut() as *mut c_void @@ -212,16 +215,17 @@ macro_rules! impl_device_attrs { $( impl TVMContext { pub fn $attr_name(&self) -> usize { - let func = function::Function::get("_GetDeviceAttr", true /* is_global */) + let func = function::Function::get("_GetDeviceAttr") .expect("API function always exists"); let dt = self.device_type.0 as usize; // `unwrap` is ok here because if there is any error, // if would occur in function call. - let ret = function::Builder::from(func) + function::Builder::from(func) .args(&[dt, self.device_id, $attr_kind]) .invoke() - .unwrap(); - ret.prim_value as usize + .unwrap() + .try_into() + .unwrap() } } )+ @@ -237,8 +241,8 @@ impl_device_attrs!((max_threads_per_block, 1); (multi_processor_count, 7); (max_thread_dimensions, 8)); -impl From for TVMContext { - fn from(ctx: ts::DLContext) -> Self { +impl From for TVMContext { + fn from(ctx: ffi::DLContext) -> Self { TVMContext { device_type: TVMDeviceType::from(ctx.device_type), device_id: ctx.device_id as usize, @@ -246,9 +250,9 @@ impl From for TVMContext { } } -impl From for ts::DLContext { +impl From for ffi::DLContext { fn from(ctx: TVMContext) -> Self { - ts::DLContext { + ffi::DLContext { device_type: ctx.device_type.into(), device_id: ctx.device_id as i32, } diff --git a/rust/frontend/src/errors.rs b/rust/frontend/src/errors.rs index a10f83c4110d..6cef88c7f285 100644 --- a/rust/frontend/src/errors.rs +++ b/rust/frontend/src/errors.rs @@ -40,7 +40,10 @@ error_chain! { ShapeError(rust_ndarray::ShapeError); NulError(ffi::NulError); IntoStringError(ffi::IntoStringError); - CommonError(common_errors::Error); + } + + links { + CommonError(common_errors::Error, common_errors::ErrorKind); } } diff --git a/rust/frontend/src/function.rs b/rust/frontend/src/function.rs index fa6bed141076..d6cb4d227165 100644 --- a/rust/frontend/src/function.rs +++ b/rust/frontend/src/function.rs @@ -15,14 +15,17 @@ use std::{ sync::Mutex, }; -use crate::{ts, ErrorKind, Module, Result, TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue}; +use crate::{ + ffi::{self, TVMValue}, + ErrorKind, Module, Result, TVMArgValue, TVMRetValue, +}; lazy_static! { static ref GLOBAL_FUNCTIONS: Mutex>> = { let mut out_size = 0 as c_int; let name = ptr::null_mut() as *mut c_char; let mut out_array = name as *mut _; - check_call!(ts::TVMFuncListGlobalNames( + check_call!(ffi::TVMFuncListGlobalNames( &mut out_size as *mut _, &mut out_array )); @@ -37,17 +40,14 @@ lazy_static! { } /// Wrapper around TVM function handle which includes `is_global` -/// indicating whether the function is global or not, `is_released` -/// to hint dropping the function handle and `is_cloned` showing +/// indicating whether the function is global or not, and `is_cloned` showing /// not to drop a cloned function from Rust side. /// The value of these fields can be accessed through their respective methods. #[derive(Debug, Hash)] pub struct Function { - pub(crate) handle: ts::TVMFunctionHandle, + pub(crate) handle: ffi::TVMFunctionHandle, // whether the registered function is global or not. is_global: bool, - // whether the function has been dropped from frontend or not. - is_released: bool, // whether the function has been cloned from frontend or not. is_cloned: bool, } @@ -56,29 +56,30 @@ unsafe impl Send for Function {} unsafe impl Sync for Function {} impl Function { - pub(crate) fn new(handle: ts::TVMFunctionHandle, is_global: bool, is_released: bool) -> Self { + pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self { Function { handle: handle, - is_global: is_global, - is_released: is_released, + is_global: false, is_cloned: false, } } /// For a given function, it returns a function by name. - pub fn get>(name: S, is_global: bool) -> Option<&'static Function> { + pub fn get>(name: S) -> Option<&'static Function> { let mut globals = GLOBAL_FUNCTIONS.lock().unwrap(); globals.get_mut(name.as_ref()).and_then(|maybe_func| { if maybe_func.is_none() { let name = CString::new(name.as_ref()).unwrap(); - let mut handle = ptr::null_mut() as ts::TVMFunctionHandle; - check_call!(ts::TVMFuncGetGlobal( + let mut handle = ptr::null_mut() as ffi::TVMFunctionHandle; + check_call!(ffi::TVMFuncGetGlobal( name.as_ptr() as *const c_char, &mut handle as *mut _ )); - maybe_func.replace(Function::new( - handle, is_global, false, /* is_released */ - )); + maybe_func.replace(Function { + handle: handle, + is_global: true, + is_cloned: false, + }); } unsafe { std::mem::transmute::, Option<&'static Function>>( @@ -89,7 +90,7 @@ impl Function { } /// Returns the underlying TVM function handle. - pub fn handle(&self) -> ts::TVMFunctionHandle { + pub fn handle(&self) -> ffi::TVMFunctionHandle { self.handle } @@ -98,12 +99,6 @@ impl Function { self.is_global } - /// Returns `true` if the underlying TVM function has been released - /// from the frontend and `false` otherwise. - pub fn is_released(&self) -> bool { - self.is_released - } - /// Returns `true` if the underlying TVM function has been cloned /// from the frontend and `false` otherwise. pub fn is_cloned(&self) -> bool { @@ -113,24 +108,18 @@ impl Function { impl Clone for Function { fn clone(&self) -> Function { - if !self.is_released && !self.is_cloned { - Self { - handle: self.handle, - is_global: self.is_global, - is_released: self.is_released, - is_cloned: true, - } - } else { - Function::new(self.handle, self.is_global, self.is_released) + Self { + handle: self.handle, + is_global: self.is_global, + is_cloned: true, } } } impl Drop for Function { fn drop(&mut self) { - if !self.is_released && !self.is_global && !self.is_cloned { - check_call!(ts::TVMFuncFree(self.handle)); - self.is_released = true; + if !self.is_global && !self.is_cloned { + check_call!(ffi::TVMFuncFree(self.handle)); } } } @@ -138,17 +127,17 @@ impl Drop for Function { /// Function builder in order to create and call functions. /// /// *Note:* Currently TVM functions accept *at most* one return value. -#[derive(Debug, Clone, Default)] +#[derive(Default)] pub struct Builder<'a, 'm> { pub func: Option<&'m Function>, - pub arg_buf: Option]>>, + pub arg_buf: Vec>, pub ret_buf: Option, } impl<'a, 'm> Builder<'a, 'm> { pub fn new( func: Option<&'m Function>, - arg_buf: Option]>>, + arg_buf: Vec>, ret_buf: Option, ) -> Self { Self { @@ -158,57 +147,40 @@ impl<'a, 'm> Builder<'a, 'm> { } } - pub fn get_function(&mut self, name: &'m str, is_global: bool) -> &mut Self { - self.func = Function::get(name, is_global); + pub fn get_function(&mut self, name: &'m str) -> &mut Self { + self.func = Function::get(name); self } /// Pushes a [`TVMArgValue`] into the function argument buffer. - pub fn arg<'b, T: ?Sized>(&mut self, arg: &'b T) -> &mut Self + pub fn arg(&mut self, arg: &'a T) -> &mut Self where - TVMValue: From<&'b T>, - TVMTypeCode: From<&'b T>, + TVMArgValue<'a>: From<&'a T>, { - let tvm_arg = TVMArgValue::from(arg); - if self.arg_buf.is_none() { - self.arg_buf = Some(Box::new([tvm_arg])); - } else { - let new_arg_buf = self.arg_buf.take().map(|bbuf| { - let mut new_arg_buf = Vec::from(bbuf); - new_arg_buf.push(tvm_arg); - let new_len = new_arg_buf.len(); - new_arg_buf.truncate(new_len); - new_arg_buf.into_boxed_slice() - }); - self.arg_buf = new_arg_buf; - } + self.arg_buf.push(arg.into()); self } /// Pushes multiple [`TVMArgValue`]s into the function argument buffer. - pub fn args<'b, T: 'b + ?Sized, I>(&mut self, args: I) -> &mut Self + pub fn args(&mut self, args: I) -> &mut Self where - I: IntoIterator, - TVMValue: From<&'b T>, - TVMTypeCode: From<&'b T>, + I: IntoIterator, + TVMArgValue<'a>: From<&'a T>, { - for arg in args { + args.into_iter().for_each(|arg| { self.arg(&arg); - } + }); self } /// Sets an output for a function that requirs a mutable output to be provided. /// See the `basics` in tests for an example. - pub fn set_output<'b, T: 'b + ?Sized>(&mut self, arg: &'b mut T) -> Result<&mut Self> + pub fn set_output(&mut self, ret: T) -> Result<&mut Self> where - TVMValue: From<&'b T>, - TVMTypeCode: From<&'b T>, + TVMRetValue: From, { if self.ret_buf.is_none() { - let tvm_ret = - unsafe { TVMRetValue::from_tvm_value(TVMValue::from(arg), TVMTypeCode::from(arg)) }; - self.ret_buf = Some(tvm_ret); + self.ret_buf = Some(ret.into()) } else { bail!(ErrorKind::AtMostOneReturn) } @@ -217,63 +189,29 @@ impl<'a, 'm> Builder<'a, 'm> { /// Calls the function that created from `Builder`. pub fn invoke(&mut self) -> Result { - self.clone()(()) - } -} - -impl<'a, 'm> FnOnce<((),)> for Builder<'a, 'm> { - type Output = Result; - extern "rust-call" fn call_once(self, _: ((),)) -> Self::Output { if self.func.is_none() { bail!("{}", ErrorKind::FunctionNotFound); } - let mut ret_val = unsafe { mem::uninitialized::() }; - let mut ret_type_code = 0 as c_int; - if self.arg_buf.is_some() { - let arg_buf = self.arg_buf?; - let mut num_args = arg_buf.len(); - let mut values = arg_buf - .iter() - .map(|tav| tav.value.inner) - .collect::>(); - let mut tcodes = arg_buf - .iter() - .map(|tav| tav.type_code as c_int) - .collect::>(); - - if self.ret_buf.is_some() { - num_args = num_args + 1; - let ret_buf = self.ret_buf?; - let (ret_val, ret_type_code) = TVMRetValue::into_tvm_value(ret_buf); - values.append(&mut vec![ret_val.inner]); - tcodes.append(&mut vec![ret_type_code as c_int]); - } - - values.truncate(num_args); - tcodes.truncate(num_args); - check_call!(ts::TVMFuncCall( - self.func?.handle, - values.as_mut_ptr(), - tcodes.as_mut_ptr(), - num_args as c_int, - &mut ret_val as *mut _, - &mut ret_type_code as *mut _ - )); - } else { - check_call!(ts::TVMFuncCall( - self.func?.handle, - ptr::null_mut(), - ptr::null_mut(), - 0 as c_int, - &mut ret_val as *mut _, - &mut ret_type_code as *mut _ - )); - } + let num_args = self.arg_buf.len(); + let (mut values, mut type_codes): (Vec, Vec) = self + .arg_buf + .iter() + .map(|tvm_arg| (tvm_arg.value, tvm_arg.type_code as ffi::TVMTypeCode)) + .unzip(); + + let mut ret_val = TVMValue { v_int64: 4 }; + let mut ret_type_code = 0; + check_call!(ffi::TVMFuncCall( + self.func?.handle, + values.as_mut_ptr(), + type_codes.as_mut_ptr() as *mut i32, + num_args as c_int, + &mut ret_val as *mut _, + &mut ret_type_code as *mut _ + )); - let ret = unsafe { - TVMRetValue::from_tvm_value(TVMValue::new(ret_val), (ret_type_code as i64).into()) - }; + let ret = unsafe { TVMRetValue::from_tvm_value(ret_val.into(), ret_type_code as i64) }; Ok(ret) } } @@ -282,22 +220,22 @@ impl<'a, 'm> FnOnce<((),)> for Builder<'a, 'm> { /// TVM functions. impl<'a, 'm> From<&'m Function> for Builder<'a, 'm> { fn from(func: &'m Function) -> Self { - Builder::new(Some(func), None, None) + Builder::new(Some(func), Vec::new(), None) } } /// Converts a mutable reference of a [`Module`] to [`Builder`]. impl<'a, 'm> From<&'m mut Module> for Builder<'a, 'm> { fn from(module: &'m mut Module) -> Self { - Builder::new(module.entry(), None, None) + Builder::new(module.entry(), Vec::new(), None) } } unsafe extern "C" fn tvm_callback( - args: *mut ts::TVMValue, + args: *mut ffi::TVMValue, type_codes: *mut c_int, num_args: c_int, - ret: ts::TVMRetValueHandle, + ret: ffi::TVMRetValueHandle, fhandle: *mut c_void, ) -> c_int { // turning off the incorrect linter complaints @@ -306,22 +244,19 @@ unsafe extern "C" fn tvm_callback( let args_list = slice::from_raw_parts_mut(args, len); let type_codes_list = slice::from_raw_parts_mut(type_codes, len); let mut local_args: Vec = Vec::new(); - let mut value = mem::uninitialized::(); + let mut value = mem::uninitialized::(); let mut tcode = mem::uninitialized::(); let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result>(fhandle); for i in 0..len { value = args_list[i]; tcode = type_codes_list[i]; - if tcode == ts::TVMTypeCode_kNodeHandle as c_int - || tcode == ts::TVMTypeCode_kFuncHandle as c_int - || tcode == ts::TVMTypeCode_kModuleHandle as c_int + if tcode == ffi::TVMTypeCode_kNodeHandle as c_int + || tcode == ffi::TVMTypeCode_kFuncHandle as c_int + || tcode == ffi::TVMTypeCode_kModuleHandle as c_int { - check_call!(ts::TVMCbArgToReturn(&mut value as *mut _, tcode)); + check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, tcode)); } - local_args.push(TVMArgValue::new( - TVMValue::new(value), - (tcode as i64).into(), - )); + local_args.push(TVMArgValue::new(value.into(), (tcode as i64).into())); } let rv = match rust_fn(local_args.as_slice()) { @@ -332,10 +267,9 @@ unsafe extern "C" fn tvm_callback( } }; - let (ret_val, ret_tcode) = TVMRetValue::into_tvm_value(rv); - let mut ret_val = ret_val.inner; + let (mut ret_val, ret_tcode) = rv.into_tvm_value(); let mut ret_type_code = ret_tcode as c_int; - check_call!(ts::TVMCFuncSetReturn( + check_call!(ffi::TVMCFuncSetReturn( ret, &mut ret_val as *mut _, &mut ret_type_code as *mut _, @@ -350,15 +284,15 @@ unsafe extern "C" fn tvm_callback_finalizer(fhandle: *mut c_void) { } fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result) -> Function { - let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle; + let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; let resource_handle = f as *mut fn(&[TVMArgValue]) -> Result; - check_call!(ts::TVMFuncCreateFromCFunc( + check_call!(ffi::TVMFuncCreateFromCFunc( Some(tvm_callback), resource_handle as *mut c_void, Some(tvm_callback_finalizer), &mut fhandle as *mut _ )); - Function::new(fhandle, false, false) + Function::new(fhandle) } /// Registers a Rust function with signature @@ -397,7 +331,7 @@ pub fn register>( ) -> Result<()> { let func = convert_to_tvm_func(f); let name = CString::new(name.as_ref())?; - check_call!(ts::TVMFuncRegisterGlobal( + check_call!(ffi::TVMFuncRegisterGlobal( name.as_ref().as_ptr() as *const c_char, func.handle(), override_ as c_int diff --git a/rust/frontend/src/lib.rs b/rust/frontend/src/lib.rs index 6e15e4f8d046..3789e1d2607e 100644 --- a/rust/frontend/src/lib.rs +++ b/rust/frontend/src/lib.rs @@ -12,20 +12,12 @@ //! Checkout the `examples` repository for more details. #![crate_name = "tvm_frontend"] -#![recursion_limit = "1024"] #![allow(non_camel_case_types, unused_unsafe)] -#![feature( - try_from, - try_trait, - fn_traits, - unboxed_closures, - box_syntax, - option_replace -)] +#![feature(try_from, try_trait, fn_traits, unboxed_closures, box_syntax)] #[macro_use] extern crate error_chain; -extern crate tvm_common as common; +extern crate tvm_common; #[macro_use] extern crate lazy_static; extern crate ndarray as rust_ndarray; @@ -36,7 +28,19 @@ use std::{ str, }; -use crate::common::ffi::ts; +pub use crate::{ + bytearray::TVMByteArray, + context::{TVMContext, TVMDeviceType}, + errors::*, + function::Function, + module::Module, + ndarray::NDArray, + tvm_common::{ + errors as common_errors, + ffi::{self, TVMType}, + packed_func::{TVMArgValue, TVMRetValue}, + }, +}; // Macro to check the return call to TVM runtime shared library. macro_rules! check_call { @@ -50,7 +54,7 @@ macro_rules! check_call { /// Gets the last error message. pub fn get_last_error() -> &'static str { unsafe { - match CStr::from_ptr(ts::TVMGetLastError()).to_str() { + match CStr::from_ptr(ffi::TVMGetLastError()).to_str() { Ok(s) => s, Err(_) => "Invalid UTF-8 message", } @@ -60,7 +64,7 @@ pub fn get_last_error() -> &'static str { pub(crate) fn set_last_error(err: &Error) { let c_string = CString::new(err.to_string()).unwrap(); unsafe { - ts::TVMAPISetLastError(c_string.as_ptr()); + ffi::TVMAPISetLastError(c_string.as_ptr()); } } @@ -71,27 +75,11 @@ pub mod context; pub mod errors; pub mod module; pub mod ndarray; -pub mod ty; pub mod value; -pub use crate::{ - bytearray::TVMByteArray, - common::{ - errors as common_errors, - ty::TVMTypeCode, - value::{TVMArgValue, TVMRetValue, TVMValue}, - }, - context::{TVMContext, TVMDeviceType}, - errors::*, - function::Function, - module::Module, - ndarray::NDArray, - ty::TVMType, -}; - /// Outputs the current TVM version. pub fn version() -> &'static str { - match str::from_utf8(ts::TVM_VERSION) { + match str::from_utf8(ffi::TVM_VERSION) { Ok(s) => s, Err(_) => "Invalid UTF-8 string", } diff --git a/rust/frontend/src/module.rs b/rust/frontend/src/module.rs index c12d9d48cf13..575cfdaae0fd 100644 --- a/rust/frontend/src/module.rs +++ b/rust/frontend/src/module.rs @@ -8,7 +8,7 @@ use std::{ ptr, }; -use crate::ts; +use tvm_common::ffi; use crate::{function::Function, ErrorKind, Result}; @@ -16,22 +16,18 @@ const ENTRY_FUNC: &'static str = "__tvm_main__"; /// Wrapper around TVM module handle which contains an entry function. /// The entry function can be applied to an imported module through [`entry_func`]. -/// Also [`is_released`] shows whether the module is dropped or not. /// /// [`entry_func`]:struct.Module.html#method.entry_func -/// [`is_released`]:struct.Module.html#method.is_released #[derive(Debug, Clone)] pub struct Module { - pub(crate) handle: ts::TVMModuleHandle, - is_released: bool, + pub(crate) handle: ffi::TVMModuleHandle, entry_func: Option, } impl Module { - pub(crate) fn new(handle: ts::TVMModuleHandle, is_released: bool) -> Self { + pub(crate) fn new(handle: ffi::TVMModuleHandle) -> Self { Self { handle, - is_released, entry_func: None, } } @@ -46,8 +42,8 @@ impl Module { /// Gets a function by name from a registered module. pub fn get_function(&self, name: &str, query_import: bool) -> Result { let name = CString::new(name)?; - let mut fhandle = ptr::null_mut() as ts::TVMFunctionHandle; - check_call!(ts::TVMModGetFunction( + let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; + check_call!(ffi::TVMModGetFunction( self.handle, name.as_ptr() as *const c_char, query_import as c_int, @@ -56,50 +52,42 @@ impl Module { if fhandle.is_null() { bail!(ErrorKind::NullHandle(format!("{}", name.into_string()?))) } else { - Ok(Function::new(fhandle, false, false)) + Ok(Function::new(fhandle)) } } /// Imports a dependent module such as `.ptx` for gpu. pub fn import_module(&self, dependent_module: Module) { - check_call!(ts::TVMModImport(self.handle, dependent_module.handle)) + check_call!(ffi::TVMModImport(self.handle, dependent_module.handle)) } /// Loads a module shared library from path. pub fn load>(path: &P) -> Result { - let ext = path.as_ref().extension()?.to_str()?; - let func = Function::get("module._LoadFromFile", true /* is_global */) - .expect("API function always exists"); - let ret: Module = call_packed!(func, path.as_ref().to_str()?, ext)?.try_into()?; + let ext = CString::new(path.as_ref().extension()?.to_str()?)?; + let func = Function::get("module._LoadFromFile").expect("API function always exists"); + let cpath = CString::new(path.as_ref().to_str()?)?; + let ret: Module = call_packed!(func, &cpath, &ext)?.try_into()?; Ok(ret) } /// Checks if a target device is enabled for a module. pub fn enabled(&self, target: &str) -> bool { - let func = Function::get("module._Enabled", true /* is_global */) - .expect("API function always exists"); + let func = Function::get("module._Enabled").expect("API function always exists"); // `unwrap` is safe here because if there is any error during the // function call, it would occur in `call_packed!`. - let ret: i64 = call_packed!(func, target).unwrap().try_into().unwrap(); + let tgt = CString::new(target).unwrap(); + let ret: i64 = call_packed!(func, &tgt).unwrap().try_into().unwrap(); ret != 0 } /// Returns the underlying module handle. - pub fn handle(&self) -> ts::TVMModuleHandle { + pub fn handle(&self) -> ffi::TVMModuleHandle { self.handle } - - /// Returns true if the underlying module has been dropped and false otherwise. - pub fn is_released(&self) -> bool { - self.is_released - } } impl Drop for Module { fn drop(&mut self) { - if !self.is_released { - check_call!(ts::TVMModFree(self.handle)); - self.is_released = true; - } + check_call!(ffi::TVMModFree(self.handle)); } } diff --git a/rust/frontend/src/ndarray.rs b/rust/frontend/src/ndarray.rs index 44dfcca3b320..0992a597c5e7 100644 --- a/rust/frontend/src/ndarray.rs +++ b/rust/frontend/src/ndarray.rs @@ -27,30 +27,29 @@ use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice}; use crate::rust_ndarray::{Array, ArrayD}; use num_traits::Num; +use tvm_common::{ffi, TVMType}; -use crate::ts; - -use crate::{Error, ErrorKind, Result, TVMByteArray, TVMContext, TVMType}; +use crate::{Error, ErrorKind, Result, TVMByteArray, TVMContext}; /// See the [`module-level documentation`](../ndarray/index.html) for more details. /// /// Wrapper around TVM array handle. #[derive(Debug)] pub struct NDArray { - pub(crate) handle: ts::TVMArrayHandle, + pub(crate) handle: ffi::TVMArrayHandle, is_view: bool, } impl NDArray { - pub(crate) fn new(handle: ts::TVMArrayHandle, is_view: bool) -> Self { + pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self { NDArray { handle: handle, - is_view: is_view, + is_view: true, } } /// Returns the underlying array handle. - pub fn handle(&self) -> ts::TVMArrayHandle { + pub fn handle(&self) -> ffi::TVMArrayHandle { self.handle } @@ -176,7 +175,7 @@ impl NDArray { /// *Note*: if something goes wrong during the copy, it will panic /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`. pub fn copy_from_buffer(&mut self, data: &mut [T]) { - check_call!(ts::TVMArrayCopyFromBytes( + check_call!(ffi::TVMArrayCopyFromBytes( self.handle, data.as_ptr() as *mut _, data.len() * mem::size_of::() @@ -194,10 +193,10 @@ impl NDArray { ) ); } - check_call!(ts::TVMArrayCopyFromTo( + check_call!(ffi::TVMArrayCopyFromTo( self.handle, target.handle, - ptr::null_mut() as ts::TVMStreamHandle + ptr::null_mut() as ffi::TVMStreamHandle )); Ok(target) } @@ -224,18 +223,21 @@ impl NDArray { /// Allocates and creates an empty NDArray given the shape, context and dtype. pub fn empty(shape: &[usize], ctx: TVMContext, dtype: TVMType) -> NDArray { - let mut handle = ptr::null_mut() as ts::TVMArrayHandle; - check_call!(ts::TVMArrayAlloc( + let mut handle = ptr::null_mut() as ffi::TVMArrayHandle; + check_call!(ffi::TVMArrayAlloc( shape.as_ptr() as *const i64, shape.len() as c_int, - dtype.inner.code as c_int, - dtype.inner.bits as c_int, - dtype.inner.lanes as c_int, + dtype.code as c_int, + dtype.bits as c_int, + dtype.lanes as c_int, ctx.device_type.0 as c_int, ctx.device_id as c_int, &mut handle as *mut _, )); - NDArray::new(handle, false) + NDArray { + handle, + is_view: false, + } } } @@ -272,7 +274,7 @@ impl_from_ndarray_rustndarray!(f32, "float"); impl Drop for NDArray { fn drop(&mut self) { if !self.is_view { - check_call!(ts::TVMArrayFree(self.handle)); + check_call!(ffi::TVMArrayFree(self.handle)); } } } diff --git a/rust/frontend/src/ty.rs b/rust/frontend/src/ty.rs deleted file mode 100644 index 7e912a517e1d..000000000000 --- a/rust/frontend/src/ty.rs +++ /dev/null @@ -1,150 +0,0 @@ -//! This module implements the required conversions from Rust types to TVM types. -//! -//! In TVM frontend only conversions from Rust's 32-bits (POD) numeric types (i32, u32, f32) -//! and 64-bits pointers are supported. - -use std::{ - fmt::{self, Display, Formatter}, - ops::{Deref, DerefMut}, -}; - -use crate::ts; - -use crate::{Function, Module, NDArray, TVMByteArray, TVMContext, TVMDeviceType, TVMTypeCode}; - -macro_rules! impl_prim_type { - ($type:ty, $variant:ident) => { - impl From<$type> for TVMTypeCode { - fn from(_arg: $type) -> Self { - TVMTypeCode::$variant - } - } - - impl<'a> From<&'a $type> for TVMTypeCode { - fn from(_arg: &$type) -> Self { - TVMTypeCode::$variant - } - } - - impl<'a> From<&'a mut $type> for TVMTypeCode { - fn from(_arg: &mut $type) -> Self { - TVMTypeCode::$variant - } - } - }; -} - -impl_prim_type!(TVMDeviceType, kDLInt); -impl_prim_type!(TVMContext, kTVMContext); -impl_prim_type!(TVMType, kTVMType); -impl_prim_type!(Function, kFuncHandle); -impl_prim_type!(Module, kModuleHandle); -impl_prim_type!(NDArray, kArrayHandle); -impl_prim_type!(TVMByteArray, kBytes); - -/// See the [module-level documentation](../ty/index.html) for more details. -/// -/// Wrapper around underlying TVMType -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub struct TVMType { - // inner fields are (code: u8, bits: u8, lanes: u16) - pub inner: ts::TVMType, -} - -impl TVMType { - pub(crate) fn new(type_code: u8, bits: u8, lanes: u16) -> Self { - TVMType { - inner: ts::TVMType { - code: type_code, - bits: bits, - lanes: lanes, - }, - } - } -} - -/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}` -/// such as "int32", "float32" or with lane "float32x1". -impl<'a> From<&'a str> for TVMType { - fn from(type_str: &'a str) -> Self { - if type_str == "bool" { - return TVMType::new(1, 1, 1); - } - - let mut type_lanes = type_str.split("x"); - let typ = type_lanes.next().expect("Missing dtype"); - let lanes = type_lanes - .next() - .map(|l| u16::from_str_radix(l, 10).expect(&format!("Bad dtype lanes: {}", l))) - .unwrap_or(1); - let (type_name, bits) = match typ.find(char::is_numeric) { - Some(idx) => { - let (name, bits_str) = typ.split_at(idx); - ( - name, - u8::from_str_radix(bits_str, 10) - .expect(&format!("Bad dtype bits: {}", bits_str)), - ) - } - None => (typ, 32), - }; - - let type_code = match type_name { - "int" => 0, - "uint" => 1, - "float" => 2, - "handle" => 3, - _ => unimplemented!(), - }; - - TVMType::new(type_code, bits, lanes) - } -} - -impl Display for TVMType { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - let ts::TVMType { code, bits, lanes } = self.inner; - if bits == 1 && lanes == 1 { - return write!(f, "bool"); - } - let mut tcode_str = match code { - 0 => "int", - 1 => "uint", - 2 => "float", - 4 => "handle", - _ => "Unknown", - } - .to_string(); - - tcode_str += &bits.to_string(); - if lanes > 1 { - tcode_str += &format!("x{}", lanes.to_string()); - } - f.write_str(&tcode_str) - } -} - -impl From for ts::DLDataType { - fn from(dtype: TVMType) -> Self { - dtype.inner - } -} - -impl From for TVMType { - fn from(dtype: ts::DLDataType) -> Self { - Self::new(dtype.code, dtype.bits, dtype.lanes) - } -} - -impl Deref for TVMType { - type Target = ts::TVMType; - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -impl DerefMut for TVMType { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.inner - } -} diff --git a/rust/frontend/src/value.rs b/rust/frontend/src/value.rs index 9fad7de4984a..e3b727cdf6e4 100644 --- a/rust/frontend/src/value.rs +++ b/rust/frontend/src/value.rs @@ -2,140 +2,108 @@ //! and their conversions needed for the types used in frontend crate. //! `TVMRetValue` is the owned version of `TVMPODValue`. -use std::{convert::TryFrom, mem, os::raw::c_void}; +use std::{convert::TryFrom, os::raw::c_void}; + +use tvm_common::{ + ensure_type, + ffi::{self, TVMValue}, +}; use crate::{ - common_errors::*, ts, Function, Module, NDArray, TVMArgValue, TVMByteArray, TVMContext, - TVMDeviceType, TVMRetValue, TVMType, TVMTypeCode, TVMValue, + common_errors::*, Function, Module, NDArray, TVMArgValue, TVMByteArray, TVMContext, TVMRetValue, }; macro_rules! impl_tvm_val_from_handle { - ($($ty:ty),+) => { - $( - impl<'a> From<&'a $ty> for TVMValue { - fn from(arg: &$ty) -> Self { - let inner = ts::TVMValue { + ($ty:ident, $type_code:expr, $handle:ty) => { + impl<'a> From<&'a $ty> for TVMArgValue<'a> { + fn from(arg: &$ty) -> Self { + TVMArgValue { + value: TVMValue { v_handle: arg.handle as *mut _ as *mut c_void, - }; - Self::new(inner) + }, + type_code: $type_code as i64, + _lifetime: std::marker::PhantomData, } } - )+ - } -} - -impl_tvm_val_from_handle!(Module, Function, NDArray); - -impl<'a> From<&'a TVMType> for TVMValue { - fn from(ty: &TVMType) -> Self { - let inner = ts::TVMValue { v_type: ty.inner }; - Self::new(inner) - } -} - -impl<'a> From<&'a TVMContext> for TVMValue { - fn from(ctx: &TVMContext) -> Self { - let inner = ts::TVMValue { - v_ctx: ctx.clone().into(), - }; - Self::new(inner) - } -} + } -impl<'a> From<&'a TVMDeviceType> for TVMValue { - fn from(dev: &TVMDeviceType) -> Self { - let inner = ts::TVMValue { - v_int64: dev.0 as i64, - }; - Self::new(inner) - } -} + impl<'a> From<&'a mut $ty> for TVMArgValue<'a> { + fn from(arg: &mut $ty) -> Self { + TVMArgValue { + value: TVMValue { + v_handle: arg.handle as *mut _ as *mut c_void, + }, + type_code: $type_code as i64, + _lifetime: std::marker::PhantomData, + } + } + } -impl<'a> From<&'a TVMByteArray> for TVMValue { - fn from(barr: &TVMByteArray) -> Self { - let inner = ts::TVMValue { - v_handle: &barr.inner as *const ts::TVMByteArray as *mut c_void, - }; - Self::new(inner) - } -} + impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for $ty { + type Error = Error; + fn try_from(arg: &TVMArgValue<'v>) -> Result<$ty> { + ensure_type!(arg, $type_code); + Ok($ty::new(unsafe { *(arg.value.v_handle as *const $handle) })) + } + } -impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for NDArray { - type Error = Error; - fn try_from(arg: &TVMArgValue<'a>) -> Result { - if arg.type_code == TVMTypeCode::kArrayHandle { - let handle = unsafe { arg.value.inner.v_handle }; - let arr_handle = unsafe { mem::transmute::<*mut c_void, ts::TVMArrayHandle>(handle) }; - Ok(Self::new(arr_handle, true)) - } else { - bail!(ErrorKind::TryFromTVMArgValueError( - stringify!(NDArray).to_string(), - arg.type_code.to_string() - )) + impl From<$ty> for TVMRetValue { + fn from(val: $ty) -> TVMRetValue { + TVMRetValue { + prim_value: 0, + box_value: box val, + type_code: $type_code as i64, + } + } } - } -} -impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for Module { - type Error = Error; - fn try_from(arg: &TVMArgValue<'a>) -> Result { - if arg.type_code == TVMTypeCode::kModuleHandle { - let handle = unsafe { arg.value.inner.v_handle }; - Ok(Self::new(handle, false)) - } else { - bail!(ErrorKind::TryFromTVMArgValueError( - stringify!(Module).to_string(), - arg.type_code.to_string() - )) + impl TryFrom for $ty { + type Error = Error; + fn try_from(ret: TVMRetValue) -> Result<$ty> { + if let Ok(handle) = ret.box_value.downcast::<$handle>() { + Ok($ty::new(*handle)) + } else { + bail!(ErrorKind::TryFromTVMRetValueError( + stringify!($type_code).to_string(), + ret.type_code, + )) + } + } } - } + }; } -impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMByteArray { - type Error = Error; - fn try_from(arg: &TVMArgValue<'a>) -> Result { - if arg.type_code == TVMTypeCode::kBytes { - unsafe { - let barr_ptr = - mem::transmute::<*mut c_void, *mut ts::TVMByteArray>(arg.value.inner.v_handle); - Ok(Self::new(*barr_ptr)) - } - } else { - bail!(ErrorKind::TryFromTVMArgValueError( - stringify!(TVMByteArray).to_string(), - arg.type_code.to_string() - )) +impl_tvm_val_from_handle!( + Function, + ffi::TVMTypeCode_kFuncHandle, + ffi::TVMFunctionHandle +); +impl_tvm_val_from_handle!(Module, ffi::TVMTypeCode_kModuleHandle, ffi::TVMModuleHandle); +impl_tvm_val_from_handle!(NDArray, ffi::TVMTypeCode_kArrayHandle, ffi::TVMArrayHandle); + +impl<'a> From<&'a TVMByteArray> for TVMValue { + fn from(barr: &TVMByteArray) -> Self { + TVMValue { + v_handle: &barr.inner as *const ffi::TVMByteArray as *mut c_void, } } } -impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMType { +impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for TVMByteArray { type Error = Error; - fn try_from(arg: &TVMArgValue<'a>) -> Result { - if arg.type_code == TVMTypeCode::kTVMType { - let ty = unsafe { arg.value.inner.v_type }; - Ok(TVMType::from(ty)) - } else { - bail!(ErrorKind::TryFromTVMArgValueError( - stringify!(TVMType).to_string(), - arg.type_code.to_string() - )) - } + fn try_from(arg: &TVMArgValue<'v>) -> Result { + ensure_type!(arg, ffi::TVMTypeCode_kBytes); + Ok(TVMByteArray::new(unsafe { + *(arg.value.v_handle as *mut ffi::TVMByteArray) + })) } } impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMContext { type Error = Error; fn try_from(arg: &TVMArgValue<'a>) -> Result { - if arg.type_code == TVMTypeCode::kTVMContext { - let ty = unsafe { arg.value.inner.v_ctx }; - Ok(TVMContext::from(ty)) - } else { - bail!(ErrorKind::TryFromTVMArgValueError( - stringify!(TVMContext).to_string(), - arg.type_code.to_string() - )) - } + ensure_type!(arg, ffi::TVMTypeCode_kTVMContext); + Ok(unsafe { arg.value.v_ctx.into() }) } } @@ -146,7 +114,7 @@ macro_rules! impl_boxed_ret_value { TVMRetValue { prim_value: 0, box_value: box val, - type_code: $code, + type_code: $code as i64, } } } @@ -158,7 +126,7 @@ macro_rules! impl_boxed_ret_value { } else { bail!(ErrorKind::TryFromTVMRetValueError( stringify!($type).to_string(), - ret.type_code.to_string() + ret.type_code )) } } @@ -166,51 +134,9 @@ macro_rules! impl_boxed_ret_value { }; } -impl_boxed_ret_value!(TVMType, TVMTypeCode::kTVMType); -impl_boxed_ret_value!(TVMContext, TVMTypeCode::kTVMContext); -impl_boxed_ret_value!(TVMByteArray, TVMTypeCode::kBytes); - -impl TryFrom for Module { - type Error = Error; - fn try_from(ret: TVMRetValue) -> Result { - if let Ok(handle) = ret.box_value.downcast::() { - Ok(Module::new(*handle, false)) - } else { - bail!(ErrorKind::TryFromTVMRetValueError( - stringify!(TVMTypeCode::kModuleHandle).to_string(), - ret.type_code.to_string() - )) - } - } -} - -impl TryFrom for Function { - type Error = Error; - fn try_from(ret: TVMRetValue) -> Result { - if let Ok(handle) = ret.box_value.downcast::() { - Ok(Function::new(*handle, false, false)) - } else { - bail!(ErrorKind::TryFromTVMRetValueError( - stringify!(TVMTypeCode::kFuncHandle).to_string(), - ret.type_code.to_string() - )) - } - } -} - -impl TryFrom for NDArray { - type Error = Error; - fn try_from(ret: TVMRetValue) -> Result { - if let Ok(handle) = ret.box_value.downcast::() { - Ok(NDArray::new(*handle, false)) - } else { - bail!(ErrorKind::TryFromTVMRetValueError( - stringify!(TVMTypeCode::kArrayHandle).to_string(), - ret.type_code.to_string() - )) - } - } -} +// impl_boxed_ret_value!(TVMType, ffi::TVMTypeCode_kTVMType); +impl_boxed_ret_value!(TVMContext, ffi::TVMTypeCode_kTVMContext); +impl_boxed_ret_value!(TVMByteArray, ffi::TVMTypeCode_kBytes); #[cfg(test)] mod tests { diff --git a/rust/frontend/tests/basics/src/main.rs b/rust/frontend/tests/basics/src/main.rs index 69b948e9117d..9430829b120e 100644 --- a/rust/frontend/tests/basics/src/main.rs +++ b/rust/frontend/tests/basics/src/main.rs @@ -1,6 +1,10 @@ +#![feature(try_from)] + extern crate ndarray as rust_ndarray; extern crate tvm_frontend as tvm; +use std::convert::TryInto; + use tvm::*; fn main() { @@ -23,12 +27,14 @@ fn main() { if cfg!(feature = "gpu") { fadd.import_module(Module::load(&concat!(env!("OUT_DIR"), "/test_add.ptx")).unwrap()); } - function::Builder::from(&mut fadd) + let ret: NDArray = function::Builder::from(&mut fadd) .arg(&arr) .arg(&arr) - .set_output(&mut ret) + .set_output(ret) .unwrap() .invoke() + .unwrap() + .try_into() .unwrap(); assert_eq!(ret.to_vec::().unwrap(), vec![6f32, 8.0]); diff --git a/rust/runtime/Cargo.toml b/rust/runtime/Cargo.toml index d48c0d98c051..cf17c3448855 100644 --- a/rust/runtime/Cargo.toml +++ b/rust/runtime/Cargo.toml @@ -23,7 +23,7 @@ nom = {version = "4.0.0", default-features = false } serde = "1.0.59" serde_derive = "1.0.79" serde_json = "1.0.17" -tvm-common = { version = "0.1.0", path = "../common/", features = ["runtime"] } +tvm-common = { version = "0.1.0", path = "../common/" } [target.'cfg(not(target_env = "sgx"))'.dependencies] num_cpus = "1.8.0" diff --git a/rust/runtime/src/array.rs b/rust/runtime/src/array.rs index 5c49515a0da3..2b948a9940cf 100644 --- a/rust/runtime/src/array.rs +++ b/rust/runtime/src/array.rs @@ -8,16 +8,13 @@ use std::{ }; use ndarray; - -use crate::{ - allocator::Allocation, - errors::*, - ffi::runtime::{ - DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, - DLDataTypeCode_kDLUInt, DLDeviceType_kDLCPU, DLTensor as _DLTensor, - }, +use tvm_common::ffi::{ + DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, + DLDeviceType_kDLCPU, DLTensor as _DLTensor, }; +use crate::{allocator::Allocation, errors::*}; + /// A `Storage` is a container which holds `Tensor` data. #[derive(PartialEq)] pub enum Storage<'a> { diff --git a/rust/runtime/src/graph.rs b/rust/runtime/src/graph.rs index 0d5e281f3f77..8da94bde9412 100644 --- a/rust/runtime/src/graph.rs +++ b/rust/runtime/src/graph.rs @@ -4,11 +4,13 @@ use nom::{alpha1, digit1, le_i32, le_i64, le_u16, le_u32, le_u64, le_u8, types:: use serde; use serde_json; -use super::{DLTensor, DataType, Module, Storage, TVMContext, Tensor}; use crate::{ - common::value::TVMArgValue, errors::{Error, ErrorKind, Result}, - ffi::runtime::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt}, + DLTensor, DataType, Module, Storage, TVMContext, Tensor, +}; +use tvm_common::{ + ffi::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt}, + TVMArgValue, }; // @see `kTVMNDArrayMagic` in `ndarray.h` diff --git a/rust/runtime/src/packed_func.rs b/rust/runtime/src/packed_func.rs index 2fe0086e9a0d..43254f4dcad9 100644 --- a/rust/runtime/src/packed_func.rs +++ b/rust/runtime/src/packed_func.rs @@ -1,16 +1,14 @@ use std::{convert::TryFrom, marker::PhantomData, os::raw::c_void}; -use super::Tensor; -use crate::ffi::runtime::{ - BackendPackedCFunc, DLTensor as _DLTensor, TVMTypeCode_kArrayHandle, - TVMTypeCode_kNDArrayContainer, TVMValue as _TVMValue, +use tvm_common::{ + ffi::{ + BackendPackedCFunc, DLTensor as _DLTensor, TVMTypeCode_kArrayHandle, + TVMTypeCode_kNDArrayContainer, TVMValue as _TVMValue, + }, + TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue, }; -use super::DLTensor; -use crate::{ - common::{TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue}, - errors::*, -}; +use crate::{array::Tensor, errors::*, DLTensor}; pub type PackedFunc = Box TVMRetValue + Send + Sync>; From 36bc545aed6f5f58498d98fc2017637fa18b537e Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Sun, 17 Feb 2019 01:20:07 +0000 Subject: [PATCH 02/17] Update runtime --- rust/.gitignore | 4 + rust/common/.gitignore | 4 - rust/common/Cargo.toml | 1 + rust/common/build.rs | 4 + rust/common/src/array.rs | 131 +++++++ rust/common/src/c_runtime_api.rs | 354 ------------------- rust/common/src/lib.rs | 2 + rust/common/src/packed_func.rs | 19 +- rust/runtime/.gitignore | 3 - rust/runtime/Cargo.toml | 2 +- rust/runtime/src/array.rs | 216 ++--------- rust/runtime/src/errors.rs | 3 +- rust/runtime/src/graph.rs | 11 +- rust/runtime/src/lib.rs | 18 +- rust/runtime/src/module.rs | 25 +- rust/runtime/src/packed_func.rs | 116 ------ rust/runtime/src/sgx.rs | 2 +- rust/runtime/src/threading.rs | 13 +- rust/runtime/tests/test_nnvm/Cargo.toml | 2 +- rust/runtime/tests/test_nnvm/src/main.rs | 1 + rust/runtime/tests/test_tvm_basic/Cargo.toml | 2 +- 21 files changed, 232 insertions(+), 701 deletions(-) create mode 100644 rust/.gitignore delete mode 100644 rust/common/.gitignore create mode 100644 rust/common/src/array.rs delete mode 100644 rust/common/src/c_runtime_api.rs delete mode 100644 rust/runtime/.gitignore delete mode 100644 rust/runtime/src/packed_func.rs diff --git a/rust/.gitignore b/rust/.gitignore new file mode 100644 index 000000000000..0cc660650780 --- /dev/null +++ b/rust/.gitignore @@ -0,0 +1,4 @@ +target/ +*.rs.bk +Cargo.lock +c_runtime_api.rs diff --git a/rust/common/.gitignore b/rust/common/.gitignore deleted file mode 100644 index 84c2ae99001c..000000000000 --- a/rust/common/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -target -**/*.rs.bk -Cargo.lock -/tvm-sys/src/bindgen.rs diff --git a/rust/common/Cargo.toml b/rust/common/Cargo.toml index dc1592b62952..3c095df6c692 100644 --- a/rust/common/Cargo.toml +++ b/rust/common/Cargo.toml @@ -9,6 +9,7 @@ bindings = [] [dependencies] error-chain = { version = "0.12.0", default-features = false } +ndarray = "0.12.1" [build-dependencies] bindgen = "0.37.4" diff --git a/rust/common/build.rs b/rust/common/build.rs index 1a45854bc4e3..90e7a11dd12a 100644 --- a/rust/common/build.rs +++ b/rust/common/build.rs @@ -14,6 +14,10 @@ fn main() { "{}/include/tvm/runtime/c_runtime_api.h", env!("TVM_HOME") )) + .header(format!( + "{}/include/tvm/runtime/c_backend_api.h", + env!("TVM_HOME") + )) .clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME"))) .blacklist_type("max_align_t") // @see rust-bindgen#550 .layout_tests(false) diff --git a/rust/common/src/array.rs b/rust/common/src/array.rs new file mode 100644 index 000000000000..9dbea1d12e1d --- /dev/null +++ b/rust/common/src/array.rs @@ -0,0 +1,131 @@ +use std::{ + any::TypeId, + convert::TryFrom, + mem, + ops::{Deref, DerefMut}, + os::raw::{c_int, c_void}, + ptr, slice, +}; + +use crate::ffi::{ + DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, + DLDeviceType_kDLCPU, DLTensor, +}; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct DataType { + pub code: usize, + pub bits: usize, + pub lanes: usize, +} + +impl DataType { + /// Returns the number of bytes occupied by an element of this `DataType`. + pub fn itemsize(&self) -> usize { + (self.bits * self.lanes) >> 3 + } + + /// Returns whether this `DataType` represents primitive type `T`. + pub fn is_type(&self) -> bool { + if self.lanes != 1 { + return false; + } + let typ = TypeId::of::(); + (typ == TypeId::of::() && self.code == 0 && self.bits == 32) + || (typ == TypeId::of::() && self.code == 0 && self.bits == 64) + || (typ == TypeId::of::() && self.code == 1 && self.bits == 32) + || (typ == TypeId::of::() && self.code == 1 && self.bits == 64) + || (typ == TypeId::of::() && self.code == 2 && self.bits == 32) + || (typ == TypeId::of::() && self.code == 2 && self.bits == 64) + } + + pub fn code(&self) -> usize { + self.code + } + + pub fn bits(&self) -> usize { + self.bits + } + + pub fn lanes(&self) -> usize { + self.lanes + } +} + +impl<'a> From<&'a DataType> for DLDataType { + fn from(dtype: &'a DataType) -> Self { + Self { + code: dtype.code as u8, + bits: dtype.bits as u8, + lanes: dtype.lanes as u16, + } + } +} + +impl From for DataType { + fn from(dtype: DLDataType) -> Self { + Self { + code: dtype.code as usize, + bits: dtype.bits as usize, + lanes: dtype.lanes as usize, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct TVMContext { + pub device_type: usize, + pub device_id: usize, +} + +impl<'a> From<&'a TVMContext> for DLContext { + fn from(ctx: &'a TVMContext) -> Self { + Self { + device_type: ctx.device_type as u32, + device_id: ctx.device_id as i32, + } + } +} + +impl Default for TVMContext { + fn default() -> Self { + Self { + device_type: DLDeviceType_kDLCPU as usize, + device_id: 0, + } + } +} + +/// `From` conversions to `DLTensor` for `ndarray::Array`. +/// Takes a reference to the `ndarray` since `DLTensor` is not owned. +macro_rules! impl_dltensor_from_ndarray { + ($type:ty, $typecode:expr) => { + impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor { + fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self { + DLTensor { + data: arr.as_mut_ptr() as *mut c_void, + ctx: DLContext { + device_type: DLDeviceType_kDLCPU, + device_id: 0, + }, + ndim: arr.ndim() as c_int, + dtype: DLDataType { + code: $typecode as u8, + bits: 8 * mem::size_of::<$type>() as u8, + lanes: 1, + }, + shape: arr.shape().as_ptr() as *const i64 as *mut i64, + strides: arr.strides().as_ptr() as *const isize as *mut i64, + byte_offset: 0, + } + } + } + }; +} + +impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat); +impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat); +impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt); +impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt); +impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt); +impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt); diff --git a/rust/common/src/c_runtime_api.rs b/rust/common/src/c_runtime_api.rs deleted file mode 100644 index 28e2518d7557..000000000000 --- a/rust/common/src/c_runtime_api.rs +++ /dev/null @@ -1,354 +0,0 @@ -/* automatically generated by rust-bindgen */ - -pub const TVM_VERSION : & 'static [ u8 ; 8usize ] = b"0.5.dev\0" ; pub const DLPACK_VERSION : u32 = 16 ; pub const _STDINT_H : u32 = 1 ; pub const _FEATURES_H : u32 = 1 ; pub const _DEFAULT_SOURCE : u32 = 1 ; pub const __USE_ISOC11 : u32 = 1 ; pub const __USE_ISOC99 : u32 = 1 ; pub const __USE_ISOC95 : u32 = 1 ; pub const __USE_POSIX_IMPLICITLY : u32 = 1 ; pub const _POSIX_SOURCE : u32 = 1 ; pub const _POSIX_C_SOURCE : u32 = 200809 ; pub const __USE_POSIX : u32 = 1 ; pub const __USE_POSIX2 : u32 = 1 ; pub const __USE_POSIX199309 : u32 = 1 ; pub const __USE_POSIX199506 : u32 = 1 ; pub const __USE_XOPEN2K : u32 = 1 ; pub const __USE_XOPEN2K8 : u32 = 1 ; pub const _ATFILE_SOURCE : u32 = 1 ; pub const __USE_MISC : u32 = 1 ; pub const __USE_ATFILE : u32 = 1 ; pub const __USE_FORTIFY_LEVEL : u32 = 0 ; pub const __GLIBC_USE_DEPRECATED_GETS : u32 = 0 ; pub const _STDC_PREDEF_H : u32 = 1 ; pub const __STDC_IEC_559__ : u32 = 1 ; pub const __STDC_IEC_559_COMPLEX__ : u32 = 1 ; pub const __STDC_ISO_10646__ : u32 = 201706 ; pub const __STDC_NO_THREADS__ : u32 = 1 ; pub const __GNU_LIBRARY__ : u32 = 6 ; pub const __GLIBC__ : u32 = 2 ; pub const __GLIBC_MINOR__ : u32 = 27 ; pub const _SYS_CDEFS_H : u32 = 1 ; pub const __glibc_c99_flexarr_available : u32 = 1 ; pub const __WORDSIZE : u32 = 64 ; pub const __WORDSIZE_TIME64_COMPAT32 : u32 = 1 ; pub const __SYSCALL_WORDSIZE : u32 = 64 ; pub const __HAVE_GENERIC_SELECTION : u32 = 1 ; pub const __GLIBC_USE_LIB_EXT2 : u32 = 0 ; pub const __GLIBC_USE_IEC_60559_BFP_EXT : u32 = 0 ; pub const __GLIBC_USE_IEC_60559_FUNCS_EXT : u32 = 0 ; pub const __GLIBC_USE_IEC_60559_TYPES_EXT : u32 = 0 ; pub const _BITS_TYPES_H : u32 = 1 ; pub const _BITS_TYPESIZES_H : u32 = 1 ; pub const __OFF_T_MATCHES_OFF64_T : u32 = 1 ; pub const __INO_T_MATCHES_INO64_T : u32 = 1 ; pub const __RLIM_T_MATCHES_RLIM64_T : u32 = 1 ; pub const __FD_SETSIZE : u32 = 1024 ; pub const _BITS_WCHAR_H : u32 = 1 ; pub const _BITS_STDINT_INTN_H : u32 = 1 ; pub const _BITS_STDINT_UINTN_H : u32 = 1 ; pub const INT8_MIN : i32 = -128 ; pub const INT16_MIN : i32 = -32768 ; pub const INT32_MIN : i32 = -2147483648 ; pub const INT8_MAX : u32 = 127 ; pub const INT16_MAX : u32 = 32767 ; pub const INT32_MAX : u32 = 2147483647 ; pub const UINT8_MAX : u32 = 255 ; pub const UINT16_MAX : u32 = 65535 ; pub const UINT32_MAX : u32 = 4294967295 ; pub const INT_LEAST8_MIN : i32 = -128 ; pub const INT_LEAST16_MIN : i32 = -32768 ; pub const INT_LEAST32_MIN : i32 = -2147483648 ; pub const INT_LEAST8_MAX : u32 = 127 ; pub const INT_LEAST16_MAX : u32 = 32767 ; pub const INT_LEAST32_MAX : u32 = 2147483647 ; pub const UINT_LEAST8_MAX : u32 = 255 ; pub const UINT_LEAST16_MAX : u32 = 65535 ; pub const UINT_LEAST32_MAX : u32 = 4294967295 ; pub const INT_FAST8_MIN : i32 = -128 ; pub const INT_FAST16_MIN : i64 = -9223372036854775808 ; pub const INT_FAST32_MIN : i64 = -9223372036854775808 ; pub const INT_FAST8_MAX : u32 = 127 ; pub const INT_FAST16_MAX : u64 = 9223372036854775807 ; pub const INT_FAST32_MAX : u64 = 9223372036854775807 ; pub const UINT_FAST8_MAX : u32 = 255 ; pub const UINT_FAST16_MAX : i32 = -1 ; pub const UINT_FAST32_MAX : i32 = -1 ; pub const INTPTR_MIN : i64 = -9223372036854775808 ; pub const INTPTR_MAX : u64 = 9223372036854775807 ; pub const UINTPTR_MAX : i32 = -1 ; pub const PTRDIFF_MIN : i64 = -9223372036854775808 ; pub const PTRDIFF_MAX : u64 = 9223372036854775807 ; pub const SIG_ATOMIC_MIN : i32 = -2147483648 ; pub const SIG_ATOMIC_MAX : u32 = 2147483647 ; pub const SIZE_MAX : i32 = -1 ; pub const WINT_MIN : u32 = 0 ; pub const WINT_MAX : u32 = 4294967295 ; pub type __u_char = :: std :: os :: raw :: c_uchar ; pub type __u_short = :: std :: os :: raw :: c_ushort ; pub type __u_int = :: std :: os :: raw :: c_uint ; pub type __u_long = :: std :: os :: raw :: c_ulong ; pub type __int8_t = :: std :: os :: raw :: c_schar ; pub type __uint8_t = :: std :: os :: raw :: c_uchar ; pub type __int16_t = :: std :: os :: raw :: c_short ; pub type __uint16_t = :: std :: os :: raw :: c_ushort ; pub type __int32_t = :: std :: os :: raw :: c_int ; pub type __uint32_t = :: std :: os :: raw :: c_uint ; pub type __int64_t = :: std :: os :: raw :: c_long ; pub type __uint64_t = :: std :: os :: raw :: c_ulong ; pub type __quad_t = :: std :: os :: raw :: c_long ; pub type __u_quad_t = :: std :: os :: raw :: c_ulong ; pub type __intmax_t = :: std :: os :: raw :: c_long ; pub type __uintmax_t = :: std :: os :: raw :: c_ulong ; pub type __dev_t = :: std :: os :: raw :: c_ulong ; pub type __uid_t = :: std :: os :: raw :: c_uint ; pub type __gid_t = :: std :: os :: raw :: c_uint ; pub type __ino_t = :: std :: os :: raw :: c_ulong ; pub type __ino64_t = :: std :: os :: raw :: c_ulong ; pub type __mode_t = :: std :: os :: raw :: c_uint ; pub type __nlink_t = :: std :: os :: raw :: c_ulong ; pub type __off_t = :: std :: os :: raw :: c_long ; pub type __off64_t = :: std :: os :: raw :: c_long ; pub type __pid_t = :: std :: os :: raw :: c_int ; # [ repr ( C ) ] # [ derive ( Debug , Copy , Clone , PartialEq , Eq ) ] pub struct __fsid_t { pub __val : [ :: std :: os :: raw :: c_int ; 2usize ] , } pub type __clock_t = :: std :: os :: raw :: c_long ; pub type __rlim_t = :: std :: os :: raw :: c_ulong ; pub type __rlim64_t = :: std :: os :: raw :: c_ulong ; pub type __id_t = :: std :: os :: raw :: c_uint ; pub type __time_t = :: std :: os :: raw :: c_long ; pub type __useconds_t = :: std :: os :: raw :: c_uint ; pub type __suseconds_t = :: std :: os :: raw :: c_long ; pub type __daddr_t = :: std :: os :: raw :: c_int ; pub type __key_t = :: std :: os :: raw :: c_int ; pub type __clockid_t = :: std :: os :: raw :: c_int ; pub type __timer_t = * mut :: std :: os :: raw :: c_void ; pub type __blksize_t = :: std :: os :: raw :: c_long ; pub type __blkcnt_t = :: std :: os :: raw :: c_long ; pub type __blkcnt64_t = :: std :: os :: raw :: c_long ; pub type __fsblkcnt_t = :: std :: os :: raw :: c_ulong ; pub type __fsblkcnt64_t = :: std :: os :: raw :: c_ulong ; pub type __fsfilcnt_t = :: std :: os :: raw :: c_ulong ; pub type __fsfilcnt64_t = :: std :: os :: raw :: c_ulong ; pub type __fsword_t = :: std :: os :: raw :: c_long ; pub type __ssize_t = :: std :: os :: raw :: c_long ; pub type __syscall_slong_t = :: std :: os :: raw :: c_long ; pub type __syscall_ulong_t = :: std :: os :: raw :: c_ulong ; pub type __loff_t = __off64_t ; pub type __caddr_t = * mut :: std :: os :: raw :: c_char ; pub type __intptr_t = :: std :: os :: raw :: c_long ; pub type __socklen_t = :: std :: os :: raw :: c_uint ; pub type __sig_atomic_t = :: std :: os :: raw :: c_int ; pub type int_least8_t = :: std :: os :: raw :: c_schar ; pub type int_least16_t = :: std :: os :: raw :: c_short ; pub type int_least32_t = :: std :: os :: raw :: c_int ; pub type int_least64_t = :: std :: os :: raw :: c_long ; pub type uint_least8_t = :: std :: os :: raw :: c_uchar ; pub type uint_least16_t = :: std :: os :: raw :: c_ushort ; pub type uint_least32_t = :: std :: os :: raw :: c_uint ; pub type uint_least64_t = :: std :: os :: raw :: c_ulong ; pub type int_fast8_t = :: std :: os :: raw :: c_schar ; pub type int_fast16_t = :: std :: os :: raw :: c_long ; pub type int_fast32_t = :: std :: os :: raw :: c_long ; pub type int_fast64_t = :: std :: os :: raw :: c_long ; pub type uint_fast8_t = :: std :: os :: raw :: c_uchar ; pub type uint_fast16_t = :: std :: os :: raw :: c_ulong ; pub type uint_fast32_t = :: std :: os :: raw :: c_ulong ; pub type uint_fast64_t = :: std :: os :: raw :: c_ulong ; pub type intmax_t = __intmax_t ; pub type uintmax_t = __uintmax_t ; pub type wchar_t = :: std :: os :: raw :: c_int ; - /// \brief CPU device - pub const DLDeviceType_kDLCPU : DLDeviceType = 1 ; - /// \brief CUDA GPU device - pub const DLDeviceType_kDLGPU : DLDeviceType = 2 ; - /// \brief Pinned CUDA GPU device by cudaMallocHost -/// \note kDLCPUPinned = kDLCPU | kDLGPU - pub const DLDeviceType_kDLCPUPinned : DLDeviceType = 3 ; - /// \brief OpenCL devices. - pub const DLDeviceType_kDLOpenCL : DLDeviceType = 4 ; - /// \brief Vulkan buffer for next generation graphics. - pub const DLDeviceType_kDLVulkan : DLDeviceType = 7 ; - /// \brief Metal for Apple GPU. - pub const DLDeviceType_kDLMetal : DLDeviceType = 8 ; - /// \brief Verilog simulator buffer - pub const DLDeviceType_kDLVPI : DLDeviceType = 9 ; - /// \brief ROCm GPUs for AMD GPUs - pub const DLDeviceType_kDLROCM : DLDeviceType = 10 ; - /// \brief Reserved extension device type, -/// used for quickly test extension device -/// The semantics can differ depending on the implementation. - pub const DLDeviceType_kDLExtDev : DLDeviceType = 12 ; - /// \brief The device type in DLContext. - pub type DLDeviceType = u32 ; - /// \brief A Device context for Tensor and operator. - # [ repr ( C ) ] # [ derive ( Debug , Copy , Clone , PartialEq , Eq ) ] pub struct DLContext { - /// \brief The device type used in the device. - pub device_type : DLDeviceType , - /// \brief The device index - pub device_id : :: std :: os :: raw :: c_int , } pub const DLDataTypeCode_kDLInt : DLDataTypeCode = 0 ; pub const DLDataTypeCode_kDLUInt : DLDataTypeCode = 1 ; pub const DLDataTypeCode_kDLFloat : DLDataTypeCode = 2 ; - /// \brief The type code options DLDataType. - pub type DLDataTypeCode = u32 ; - /// \brief The data type the tensor can hold. -/// -/// Examples -/// - float: type_code = 2, bits = 32, lanes=1 -/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 -/// - int8: type_code = 0, bits = 8, lanes=1 - # [ repr ( C ) ] # [ derive ( Debug , Copy , Clone , PartialEq , Eq ) ] pub struct DLDataType { - /// \brief Type code of base types. - /// We keep it uint8_t instead of DLDataTypeCode for minimal memory - /// footprint, but the value should be one of DLDataTypeCode enum values. - /// - pub code : u8 , - /// \brief Number of bits, common choices are 8, 16, 32. - pub bits : u8 , - /// \brief Number of lanes in the type, used for vector types. - pub lanes : u16 , } - /// \brief Plain C Tensor object, does not manage memory. - # [ repr ( C ) ] # [ derive ( Debug , Copy , Clone , PartialEq , Eq ) ] pub struct DLTensor { - /// \brief The opaque data pointer points to the allocated data. - /// This will be CUDA device pointer or cl_mem handle in OpenCL. - /// This pointer is always aligns to 256 bytes as in CUDA. - pub data : * mut :: std :: os :: raw :: c_void , - /// \brief The device context of the tensor - pub ctx : DLContext , - /// \brief Number of dimensions - pub ndim : :: std :: os :: raw :: c_int , - /// \brief The data type of the pointer - pub dtype : DLDataType , - /// \brief The shape of the tensor - pub shape : * mut i64 , - /// \brief strides of the tensor, - /// can be NULL, indicating tensor is compact. - pub strides : * mut i64 , - /// \brief The offset in bytes to the beginning pointer to data - pub byte_offset : u64 , } - /// \brief C Tensor object, manage memory of DLTensor. This data structure is -/// intended to faciliate the borrowing of DLTensor by another framework. It is -/// not meant to transfer the tensor. When the borrowing framework doesn't need -/// the tensor, it should call the deleter to notify the host that the resource -/// is no longer needed. - # [ repr ( C ) ] # [ derive ( Debug , Copy , Clone , PartialEq , Eq ) ] pub struct DLManagedTensor { - /// \brief DLTensor which is being memory managed - pub dl_tensor : DLTensor , - /// \brief the context of the original host framework of DLManagedTensor in - /// which DLManagedTensor is used in the framework. It can also be NULL. - pub manager_ctx : * mut :: std :: os :: raw :: c_void , - /// \brief Destructor signature void (*)(void*) - this should be called - /// to destruct manager_ctx which holds the DLManagedTensor. It can be NULL - /// if there is no way for the caller to provide a reasonable destructor. - pub deleter : :: std :: option :: Option < unsafe extern "C" fn ( self_ : * mut DLManagedTensor ) > , } - /// \brief type of array index. - pub type tvm_index_t = i64 ; pub const TVMDeviceExtType_kDLAOCL : TVMDeviceExtType = 5 ; pub const TVMDeviceExtType_kDLSDAccel : TVMDeviceExtType = 6 ; pub const TVMDeviceExtType_kOpenGL : TVMDeviceExtType = 11 ; - /// \brief Extension device types in TVM - pub type TVMDeviceExtType = u32 ; pub const TVMTypeCode_kHandle : TVMTypeCode = 3 ; pub const TVMTypeCode_kNull : TVMTypeCode = 4 ; pub const TVMTypeCode_kTVMType : TVMTypeCode = 5 ; pub const TVMTypeCode_kTVMContext : TVMTypeCode = 6 ; pub const TVMTypeCode_kArrayHandle : TVMTypeCode = 7 ; pub const TVMTypeCode_kNodeHandle : TVMTypeCode = 8 ; pub const TVMTypeCode_kModuleHandle : TVMTypeCode = 9 ; pub const TVMTypeCode_kFuncHandle : TVMTypeCode = 10 ; pub const TVMTypeCode_kStr : TVMTypeCode = 11 ; pub const TVMTypeCode_kBytes : TVMTypeCode = 12 ; pub const TVMTypeCode_kNDArrayContainer : TVMTypeCode = 13 ; pub const TVMTypeCode_kExtBegin : TVMTypeCode = 15 ; pub const TVMTypeCode_kNNVMFirst : TVMTypeCode = 16 ; pub const TVMTypeCode_kNNVMLast : TVMTypeCode = 20 ; pub const TVMTypeCode_kExtReserveEnd : TVMTypeCode = 64 ; pub const TVMTypeCode_kExtEnd : TVMTypeCode = 128 ; - /// \brief The type code in TVMType -/// \note TVMType is used in two places. - pub type TVMTypeCode = u32 ; - /// \brief The data type used in TVM Runtime. -/// -/// Examples -/// - float: type_code = 2, bits = 32, lanes=1 -/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 -/// - int8: type_code = 0, bits = 8, lanes=1 -/// -/// \note Arguments TVM API function always takes bits=64 and lanes=1 - pub type TVMType = DLDataType ; - /// \brief The Device information, abstract away common device types. - pub type TVMContext = DLContext ; - /// \brief The tensor array stucture to TVM API. - pub type TVMArray = DLTensor ; - /// \brief the array handle - pub type TVMArrayHandle = * mut TVMArray ; - /// \brief Union type of values -/// being passed through API and function calls. - # [ repr ( C ) ] # [ derive ( Copy , Clone ) ] pub union TVMValue { pub v_int64 : i64 , pub v_float64 : f64 , pub v_handle : * mut :: std :: os :: raw :: c_void , pub v_str : * const :: std :: os :: raw :: c_char , pub v_type : TVMType , pub v_ctx : TVMContext , _bindgen_union_align : u64 , } - /// \brief Byte array type used to pass in byte array -/// When kBytes is used as data type. - # [ repr ( C ) ] # [ derive ( Debug , Copy , Clone , PartialEq , Eq ) ] pub struct TVMByteArray { pub data : * const :: std :: os :: raw :: c_char , pub size : usize , } - /// \brief Handle to TVM runtime modules. - pub type TVMModuleHandle = * mut :: std :: os :: raw :: c_void ; - /// \brief Handle to packed function handle. - pub type TVMFunctionHandle = * mut :: std :: os :: raw :: c_void ; - /// \brief Handle to hold return value. - pub type TVMRetValueHandle = * mut :: std :: os :: raw :: c_void ; - /// \brief The stream that is specific to device -/// can be NULL, which indicates the default one. - pub type TVMStreamHandle = * mut :: std :: os :: raw :: c_void ; extern "C" { - /// \brief Used for implementing C API function. -/// Set last error message before return. -/// \param msg The error message to be set. - pub fn TVMAPISetLastError ( msg : * const :: std :: os :: raw :: c_char ) ; } extern "C" { - /// \brief return str message of the last error -/// all function in this file will return 0 when success -/// and -1 when an error occured, -/// TVMGetLastError can be called to retrieve the error -/// -/// this function is threadsafe and can be called by different thread -/// \return error info - pub fn TVMGetLastError ( ) -> * const :: std :: os :: raw :: c_char ; } extern "C" { - /// \brief Load module from file. -/// \param file_name The file name to load the module from. -/// \param format The format of the module. -/// \param out The result module -/// -/// \return 0 when success, -1 when failure happens -/// \note The resulting module do not contain import relation. -/// It can be reconstructed by TVMModImport. - pub fn TVMModLoadFromFile ( file_name : * const :: std :: os :: raw :: c_char , format : * const :: std :: os :: raw :: c_char , out : * mut TVMModuleHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { - /// \brief Add dep to mod's dependency. -/// This allows functions in this module to use modules. -/// -/// \param mod The module handle. -/// \param dep The dependent module to be imported. -/// \return 0 when success, -1 when failure happens - pub fn TVMModImport ( mod_ : TVMModuleHandle , dep : TVMModuleHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { - /// \brief Get function from the module. -/// \param mod The module handle. -/// \param func_name The name of the function. -/// \param query_imports Whether to query imported modules -/// \param out The result function, can be NULL if it is not available. -/// \return 0 when no error is thrown, -1 when failure happens - pub fn TVMModGetFunction ( mod_ : TVMModuleHandle , func_name : * const :: std :: os :: raw :: c_char , query_imports : :: std :: os :: raw :: c_int , out : * mut TVMFunctionHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { - /// \brief Free front-end extension type resource. -/// \param handle The extension handle. -/// \param type_code The type of of the extension type. -/// \return 0 when success, -1 when failure happens - pub fn TVMExtTypeFree ( handle : * mut :: std :: os :: raw :: c_void , type_code : :: std :: os :: raw :: c_int ) -> :: std :: os :: raw :: c_int ; } extern "C" { - /// \brief Free the Module -/// \param mod The module to be freed. -/// -/// \note This may not free up the module's resources. -/// If there is active TVMFunctionHandle uses the module -/// Or if this module is imported by another active module. -/// -/// The all functions remains valid until TVMFuncFree is called. -/// \return 0 when success, -1 when failure happens - pub fn TVMModFree ( mod_ : TVMModuleHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { - /// \brief Free the function when it is no longer needed. -/// \param func The function handle -/// \return 0 when success, -1 when failure happens - pub fn TVMFuncFree ( func : TVMFunctionHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { - /// \brief Call a Packed TVM Function. -/// -/// \param func node handle of the function. -/// \param arg_values The arguments -/// \param type_codes The type codes of the arguments -/// \param num_args Number of arguments. -/// -/// \param ret_val The return value. -/// \param ret_type_code the type code of return value. -/// -/// \return 0 when success, -1 when failure happens -/// \note TVM calls always exchanges with type bits=64, lanes=1 -/// -/// \note API calls always exchanges with type bits=64, lanes=1 -/// If API call returns container handles (e.g. FunctionHandle) -/// these handles should be managed by the front-end. -/// The front-end need to call free function (e.g. TVMFuncFree) -/// to free these handles. - pub fn TVMFuncCall ( func : TVMFunctionHandle , arg_values : * mut TVMValue , type_codes : * mut :: std :: os :: raw :: c_int , num_args : :: std :: os :: raw :: c_int , ret_val : * mut TVMValue , ret_type_code : * mut :: std :: os :: raw :: c_int ) -> :: std :: os :: raw :: c_int ; } extern "C" { - /// \brief Set the return value of TVMPackedCFunc. -/// -/// This function is called by TVMPackedCFunc to set the return value. -/// When this function is not called, the function returns null by default. -/// -/// \param ret The return value handle, pass by ret in TVMPackedCFunc -/// \param value The value to be returned. -/// \param type_code The type of the value to be returned. -/// \param num_ret Number of return values, for now only 1 is supported. - pub fn TVMCFuncSetReturn ( ret : TVMRetValueHandle , value : * mut TVMValue , type_code : * mut :: std :: os :: raw :: c_int , num_ret : :: std :: os :: raw :: c_int ) -> :: std :: os :: raw :: c_int ; } extern "C" { - /// \brief Inplace translate callback argument value to return value. -/// This is only needed for non-POD arguments. -/// -/// \param value The value to be translated. -/// \param code The type code to be translated. -/// \note This function will do a shallow copy when necessary. -/// -/// \return 0 when success, -1 when failure happens. - pub fn TVMCbArgToReturn ( value : * mut TVMValue , code : :: std :: os :: raw :: c_int ) -> :: std :: os :: raw :: c_int ; } - /// \brief C type of packed function. -/// -/// \param args The arguments -/// \param type_codes The type codes of the arguments -/// \param num_args Number of arguments. -/// \param ret The return value handle. -/// \param resource_handle The handle additional resouce handle from fron-end. -/// \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. -/// \sa TVMCFuncSetReturn - pub type TVMPackedCFunc = :: std :: option :: Option < unsafe extern "C" fn ( args : * mut TVMValue , type_codes : * mut :: std :: os :: raw :: c_int , num_args : :: std :: os :: raw :: c_int , ret : TVMRetValueHandle , resource_handle : * mut :: std :: os :: raw :: c_void ) -> :: std :: os :: raw :: c_int > ; - /// \brief C callback to free the resource handle in C packed function. -/// \param resource_handle The handle additional resouce handle from fron-end. - pub type TVMPackedCFuncFinalizer = :: std :: option :: Option < unsafe extern "C" fn ( resource_handle : * mut :: std :: os :: raw :: c_void ) > ; - /// \brief Signature for extension function declarer. -/// -/// TVM call this function to get the extension functions -/// The declarer will call register_func to register function and their name. -/// -/// \param register_func_handle The register function -/// \return 0 if success, -1 if failure happens - pub type TVMExtensionFuncDeclarer = :: std :: option :: Option < unsafe extern "C" fn ( register_func_handle : TVMFunctionHandle ) -> :: std :: os :: raw :: c_int > ; extern "C" { - /// \brief Wrap a TVMPackedCFunc to become a FunctionHandle. -/// -/// The resource_handle will be managed by TVM API, until the function is no longer used. -/// -/// \param func The packed C function. -/// \param resource_handle The resource handle from front-end, can be NULL. -/// \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL -/// \param out the result function handle. -/// \return 0 when success, -1 when failure happens - pub fn TVMFuncCreateFromCFunc ( func : TVMPackedCFunc , resource_handle : * mut :: std :: os :: raw :: c_void , fin : TVMPackedCFuncFinalizer , out : * mut TVMFunctionHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { - /// \brief Register the function to runtime's global table. -/// -/// The registered function then can be pulled by the backend by the name. -/// -/// \param name The name of the function. -/// \param f The function to be registered. -/// \param override Whether allow override already registered function. - pub fn TVMFuncRegisterGlobal ( name : * const :: std :: os :: raw :: c_char , f : TVMFunctionHandle , override_ : :: std :: os :: raw :: c_int ) -> :: std :: os :: raw :: c_int ; } extern "C" { - /// \brief Get a global function. -/// -/// \param name The name of the function. -/// \param out the result function pointer, NULL if it does not exist. -/// -/// \note The function handle of global function is managed by TVM runtime, -/// So TVMFuncFree is should not be called when it get deleted. - pub fn TVMFuncGetGlobal ( name : * const :: std :: os :: raw :: c_char , out : * mut TVMFunctionHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { - /// \brief List all the globally registered function name -/// \param out_size The number of functions -/// \param out_array The array of function names. -/// \return 0 when success, -1 when failure happens - pub fn TVMFuncListGlobalNames ( out_size : * mut :: std :: os :: raw :: c_int , out_array : * mut * mut * const :: std :: os :: raw :: c_char ) -> :: std :: os :: raw :: c_int ; } extern "C" { - /// \brief Allocate a nd-array's memory, -/// including space of shape, of given spec. -/// -/// \param shape The shape of the array, the data content will be copied to out -/// \param ndim The number of dimension of the array. -/// \param dtype_code The type code of the dtype -/// \param dtype_bits The number of bits of dtype -/// \param dtype_lanes The number of lanes in the dtype. -/// \param device_type The device type of context -/// \param device_id The device id of context. -/// \param out The output handle. -/// \return 0 when success, -1 when failure happens - pub fn TVMArrayAlloc ( shape : * const tvm_index_t , ndim : :: std :: os :: raw :: c_int , dtype_code : :: std :: os :: raw :: c_int , dtype_bits : :: std :: os :: raw :: c_int , dtype_lanes : :: std :: os :: raw :: c_int , device_type : :: std :: os :: raw :: c_int , device_id : :: std :: os :: raw :: c_int , out : * mut TVMArrayHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { - /// \brief Free the TVM Array. -/// \param handle The array handle to be freed. -/// \return 0 when success, -1 when failure happens - pub fn TVMArrayFree ( handle : TVMArrayHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { - /// \brief Copy array data from CPU byte array. -/// \param handle The array handle. -/// \param data the data pointer -/// \param nbytes The number of bytes to copy. -/// \return 0 when success, -1 when failure happens - pub fn TVMArrayCopyFromBytes ( handle : TVMArrayHandle , data : * mut :: std :: os :: raw :: c_void , nbytes : usize ) -> :: std :: os :: raw :: c_int ; } extern "C" { - /// \brief Copy array data to CPU byte array. -/// \param handle The array handle. -/// \param data the data pointer -/// \param nbytes The number of bytes to copy. -/// \return 0 when success, -1 when failure happens - pub fn TVMArrayCopyToBytes ( handle : TVMArrayHandle , data : * mut :: std :: os :: raw :: c_void , nbytes : usize ) -> :: std :: os :: raw :: c_int ; } extern "C" { - /// \brief Copy the array, both from and to must be valid during the copy. -/// \param from The array to be copied from. -/// \param to The target space. -/// \param stream The stream where the copy happens, can be NULL. -/// \return 0 when success, -1 when failure happens - pub fn TVMArrayCopyFromTo ( from : TVMArrayHandle , to : TVMArrayHandle , stream : TVMStreamHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { - /// \brief Produce an array from the DLManagedTensor that shares data memory -/// with the DLManagedTensor. -/// \param from The source DLManagedTensor. -/// \param out The output array handle. -/// \return 0 when success, -1 when failure happens - pub fn TVMArrayFromDLPack ( from : * mut DLManagedTensor , out : * mut TVMArrayHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { - /// \brief Produce a DLMangedTensor from the array that shares data memory with -/// the array. -/// \param from The source array. -/// \param out The DLManagedTensor handle. -/// \return 0 when success, -1 when failure happens - pub fn TVMArrayToDLPack ( from : TVMArrayHandle , out : * mut * mut DLManagedTensor ) -> :: std :: os :: raw :: c_int ; } extern "C" { - /// \brief Delete (free) a DLManagedTensor's data. -/// \param dltensor Pointer to the DLManagedTensor. - pub fn TVMDLManagedTensorCallDeleter ( dltensor : * mut DLManagedTensor ) ; } extern "C" { - /// \brief Create a new runtime stream. -/// -/// \param device_type The device type of context -/// \param device_id The device id of context -/// \param out The new stream handle -/// \return 0 when success, -1 when failure happens - pub fn TVMStreamCreate ( device_type : :: std :: os :: raw :: c_int , device_id : :: std :: os :: raw :: c_int , out : * mut TVMStreamHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { - /// \brief Free a created stream handle. -/// -/// \param device_type The device type of context -/// \param device_id The device id of context -/// \param stream The stream to be freed -/// \return 0 when success, -1 when failure happens - pub fn TVMStreamFree ( device_type : :: std :: os :: raw :: c_int , device_id : :: std :: os :: raw :: c_int , stream : TVMStreamHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { - /// \brief Set the runtime stream of current thread to be stream. -/// The subsequent calls to the same device_type -/// will use the setted stream handle. -/// The specific type of stream is runtime device dependent. -/// -/// \param device_type The device type of context -/// \param device_id The device id of context. -/// \param handle The stream handle. -/// \return 0 when success, -1 when failure happens - pub fn TVMSetStream ( device_type : :: std :: os :: raw :: c_int , device_id : :: std :: os :: raw :: c_int , handle : TVMStreamHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { - /// \brief Wait until all computations on stream completes. -/// -/// \param device_type The device type of context -/// \param device_id The device id of context. -/// \param stream The stream to be synchronized. -/// \return 0 when success, -1 when failure happens - pub fn TVMSynchronize ( device_type : :: std :: os :: raw :: c_int , device_id : :: std :: os :: raw :: c_int , stream : TVMStreamHandle ) -> :: std :: os :: raw :: c_int ; } extern "C" { - /// \brief Synchronize two streams of execution. -/// -/// \param device_type The device type of context -/// \param device_id The device id of context -/// \param src The source stream to synchronize. -/// \param dst The destination stream to synchronize. -/// \return 0 when success, -1 when failure happens - pub fn TVMStreamStreamSynchronize ( device_type : :: std :: os :: raw :: c_int , device_id : :: std :: os :: raw :: c_int , src : TVMStreamHandle , dst : TVMStreamHandle ) -> :: std :: os :: raw :: c_int ; } \ No newline at end of file diff --git a/rust/common/src/lib.rs b/rust/common/src/lib.rs index bb1034c20d73..f3d3ac915ee6 100644 --- a/rust/common/src/lib.rs +++ b/rust/common/src/lib.rs @@ -21,7 +21,9 @@ pub mod ffi { extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int; } +pub mod array; pub mod errors; +#[macro_use] pub mod packed_func; pub mod value; diff --git a/rust/common/src/packed_func.rs b/rust/common/src/packed_func.rs index cef5a4dba826..8d02650debdd 100644 --- a/rust/common/src/packed_func.rs +++ b/rust/common/src/packed_func.rs @@ -1,5 +1,6 @@ use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void}; +pub use crate::ffi::TVMValue; use crate::{errors::*, ffi::*, value::*}; pub type PackedFunc = Box TVMRetValue + Send + Sync>; @@ -309,21 +310,3 @@ impl_prim_ret_value!(f64, 2); impl_prim_ret_value!(isize, 0); impl_prim_ret_value!(usize, 1); impl_boxed_ret_value!(String, 11); - -// @see `WrapPackedFunc` in `llvm_module.cc`. -pub(super) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc { - box move |args: &[TVMArgValue]| { - func( - args.iter() - .map(|ref arg| arg.value) - .collect::>() - .as_ptr(), - args.iter() - .map(|ref arg| arg.type_code as i32) - .collect::>() - .as_ptr() as *const i32, - args.len() as i32, - ); - TVMRetValue::default() - } -} diff --git a/rust/runtime/.gitignore b/rust/runtime/.gitignore deleted file mode 100644 index 230ab66104df..000000000000 --- a/rust/runtime/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -Cargo.lock -target/ -**/*.rs.bk diff --git a/rust/runtime/Cargo.toml b/rust/runtime/Cargo.toml index cf17c3448855..dcb7de5836f3 100644 --- a/rust/runtime/Cargo.toml +++ b/rust/runtime/Cargo.toml @@ -18,7 +18,7 @@ bounded-spsc-queue = "0.4.0" error-chain = { version = "0.12.0", default-features = false } itertools = "0.7.8" lazy_static = "1.1.0" -ndarray = "0.11.2" +ndarray="0.12.1" nom = {version = "4.0.0", default-features = false } serde = "1.0.59" serde_derive = "1.0.79" diff --git a/rust/runtime/src/array.rs b/rust/runtime/src/array.rs index 2b948a9940cf..eceba761535d 100644 --- a/rust/runtime/src/array.rs +++ b/rust/runtime/src/array.rs @@ -1,16 +1,12 @@ -use std::{ - any::TypeId, - convert::TryFrom, - mem, - ops::{Deref, DerefMut}, - os::raw::{c_int, c_void}, - ptr, slice, -}; +use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice}; use ndarray; -use tvm_common::ffi::{ - DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, - DLDeviceType_kDLCPU, DLTensor as _DLTensor, +use tvm_common::{ + array::{DataType, TVMContext}, + ffi::{ + DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, + DLDataTypeCode_kDLUInt, DLTensor, + }, }; use crate::{allocator::Allocation, errors::*}; @@ -234,6 +230,27 @@ impl<'a> Tensor<'a> { byte_offset: 0, } } + + pub(crate) fn as_dltensor(&self, flatten: bool) -> DLTensor { + assert!(!flatten || self.is_contiguous()); + DLTensor { + data: unsafe { self.data.as_mut_ptr().offset(self.byte_offset) } as *mut c_void, + ctx: DLContext::from(&self.ctx), + ndim: if flatten { 1 } else { self.shape.len() } as i32, + dtype: DLDataType::from(&self.dtype), + shape: if flatten { + &self.size as *const _ as *mut i64 + } else { + self.shape.as_ptr() + } as *mut i64, + strides: if flatten || self.is_contiguous() { + ptr::null_mut() + } else { + self.strides.as_ref().unwrap().as_ptr() + } as *mut i64, + byte_offset: 0, + } + } } /// Conversions to `ndarray::Array` from `Tensor`, if the types match. @@ -260,120 +277,9 @@ macro_rules! impl_ndarray_try_from_tensor { }; } -impl_ndarray_try_from_tensor!(i32, DTYPE_INT32); -impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32); -impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32); -impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64); - -pub struct DLTensor { - pub(crate) inner: _DLTensor, -} - -impl Deref for DLTensor { - type Target = _DLTensor; - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -impl DerefMut for DLTensor { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.inner - } -} - -impl DLTensor { - pub(crate) fn new(raw: _DLTensor) -> Self { - Self { inner: raw } - } - - pub(crate) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self { - assert!(!flatten || tensor.is_contiguous()); - Self { - inner: _DLTensor { - data: unsafe { tensor.data.as_mut_ptr().offset(tensor.byte_offset) } as *mut c_void, - ctx: DLContext::from(&tensor.ctx), - ndim: if flatten { 1 } else { tensor.shape.len() } as i32, - dtype: DLDataType::from(&tensor.dtype), - shape: if flatten { - &tensor.size as *const _ as *mut i64 - } else { - tensor.shape.as_ptr() - } as *mut i64, - strides: if flatten || tensor.is_contiguous() { - ptr::null_mut() - } else { - tensor.strides.as_ref().unwrap().as_ptr() - } as *mut i64, - byte_offset: 0, - }, - } - } -} - -impl<'a, 't> From<&'a Tensor<'t>> for DLTensor { - fn from(tensor: &'a Tensor<'t>) -> Self { - DLTensor::from_tensor(tensor, false /* flatten */) - } -} - -impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor { - fn from(tensor: &'a mut Tensor<'t>) -> Self { - DLTensor::from_tensor(tensor, false /* flatten */) - } -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub struct DataType { - pub(crate) code: usize, - pub(crate) bits: usize, - pub(crate) lanes: usize, -} - -impl DataType { - /// Returns the number of bytes occupied by an element of this `DataType`. - pub fn itemsize(&self) -> usize { - (self.bits * self.lanes) >> 3 - } - - /// Returns whether this `DataType` represents primitive type `T`. - pub fn is_type(&self) -> bool { - if self.lanes != 1 { - return false; - } - let typ = TypeId::of::(); - (typ == TypeId::of::() && self.code == 0 && self.bits == 32) - || (typ == TypeId::of::() && self.code == 0 && self.bits == 64) - || (typ == TypeId::of::() && self.code == 1 && self.bits == 32) - || (typ == TypeId::of::() && self.code == 1 && self.bits == 64) - || (typ == TypeId::of::() && self.code == 2 && self.bits == 32) - || (typ == TypeId::of::() && self.code == 2 && self.bits == 64) - } -} - -impl<'a> From<&'a DataType> for DLDataType { - fn from(dtype: &'a DataType) -> Self { - Self { - code: dtype.code as u8, - bits: dtype.bits as u8, - lanes: dtype.lanes as u16, - } - } -} - -impl From for DataType { - fn from(dtype: DLDataType) -> Self { - Self { - code: dtype.code as usize, - bits: dtype.bits as usize, - lanes: dtype.lanes as usize, - } - } -} - macro_rules! make_dtype_const { ($name: ident, $code: ident, $bits: expr, $lanes: expr) => { - const $name: DataType = DataType { + pub const $name: DataType = DataType { code: $code as usize, bits: $bits, lanes: $lanes, @@ -386,28 +292,20 @@ make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1); // make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1); make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1); make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1); +impl_ndarray_try_from_tensor!(i32, DTYPE_INT32); +impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32); +impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32); +impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64); -#[derive(Debug, Clone, Copy, PartialEq)] -pub struct TVMContext { - pub(crate) device_type: usize, - pub(crate) device_id: usize, -} - -impl<'a> From<&'a TVMContext> for DLContext { - fn from(ctx: &'a TVMContext) -> Self { - Self { - device_type: ctx.device_type as u32, - device_id: ctx.device_id as i32, - } +impl<'a, 't> From<&'a Tensor<'t>> for DLTensor { + fn from(tensor: &'a Tensor<'t>) -> Self { + Tensor::as_dltensor(tensor, false /* flatten */) } } -impl Default for TVMContext { - fn default() -> Self { - Self { - device_type: DLDeviceType_kDLCPU as usize, - device_id: 0, - } +impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor { + fn from(tensor: &'a mut Tensor<'t>) -> Self { + Tensor::as_dltensor(tensor, false /* flatten */) } } @@ -460,42 +358,6 @@ macro_rules! impl_tensor_from_ndarray { }; } -/// `From` conversions to `DLTensor` for `ndarray::Array`. -/// Takes a reference to the `ndarray` since `DLTensor` is not owned. -macro_rules! impl_dltensor_from_ndarray { - ($type:ty, $typecode:expr) => { - impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor { - fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self { - DLTensor { - inner: _DLTensor { - data: arr.as_mut_ptr() as *mut c_void, - ctx: DLContext { - device_type: DLDeviceType_kDLCPU, - device_id: 0, - }, - ndim: arr.ndim() as c_int, - dtype: DLDataType { - code: $typecode as u8, - bits: 8 * mem::size_of::<$type>() as u8, - lanes: 1, - }, - shape: arr.shape().as_ptr() as *const i64 as *mut i64, - strides: arr.strides().as_ptr() as *const isize as *mut i64, - byte_offset: 0, - }, - } - } - } - }; -} - -impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat); -impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat); -impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt); -impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt); -impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt); -impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt); - impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat); impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat); impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt); diff --git a/rust/runtime/src/errors.rs b/rust/runtime/src/errors.rs index cf7723034882..ab56d2580fd2 100644 --- a/rust/runtime/src/errors.rs +++ b/rust/runtime/src/errors.rs @@ -4,7 +4,6 @@ use alloc::alloc; use std::alloc; use std::num; -use crate::common::errors as common_errors; use ndarray; use serde_json; @@ -25,7 +24,7 @@ error_chain! { GraphDeserialize(serde_json::Error); ParseInt(num::ParseIntError); ShapeError(ndarray::ShapeError); - CommonError(common_errors::Error); + CommonError(tvm_common::errors::Error); } } diff --git a/rust/runtime/src/graph.rs b/rust/runtime/src/graph.rs index 8da94bde9412..8f2d5630a5b3 100644 --- a/rust/runtime/src/graph.rs +++ b/rust/runtime/src/graph.rs @@ -6,10 +6,11 @@ use serde_json; use crate::{ errors::{Error, ErrorKind, Result}, - DLTensor, DataType, Module, Storage, TVMContext, Tensor, + Module, Storage, Tensor, }; use tvm_common::{ - ffi::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt}, + array::{DataType, TVMContext}, + ffi::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, DLTensor}, TVMArgValue, }; @@ -199,10 +200,10 @@ impl<'m, 't> GraphExecutor<'m, 't> { }) .collect::>>()?; - let align = dtypes.iter().map(|dtype| dtype.bits as usize).max(); + let align = dtypes.iter().map(|dtype| dtype.bits() as usize).max(); let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1]; for (i, &storage_id) in storage_ids.iter().enumerate() { - let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3; + let dtype_size = dtypes[i].bits() * dtypes[i].lanes() >> 3; let nbytes = dtype_size * shapes[i].iter().product::() as usize; storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]); } @@ -266,7 +267,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { .map(|idx| { let tensor = &tensors[idx?]; Ok(if attrs.flatten_data { - DLTensor::from_tensor(tensor, true /* flatten */) + Tensor::as_dltensor(tensor, true /* flatten */) } else { DLTensor::from(tensor) }) diff --git a/rust/runtime/src/lib.rs b/rust/runtime/src/lib.rs index da030bc4be65..9716cda3bf46 100644 --- a/rust/runtime/src/lib.rs +++ b/rust/runtime/src/lib.rs @@ -39,27 +39,29 @@ extern crate serde; #[macro_use] extern crate serde_derive; extern crate serde_json; -extern crate tvm_common as common; +extern crate tvm_common; mod allocator; mod array; pub mod errors; -mod module; -#[macro_use] -mod packed_func; mod graph; +mod module; #[cfg(target_env = "sgx")] #[macro_use] pub mod sgx; mod threading; mod workspace; -pub use crate::common::{errors::*, ffi, TVMArgValue, TVMRetValue}; - -pub use self::{ - array::*, errors::*, graph::*, module::*, packed_func::*, threading::*, workspace::*, +pub use tvm_common::{ + call_packed, + errors::*, + ffi::{self, DLTensor}, + packed_func::{self, *}, + TVMArgValue, TVMRetValue, }; +pub use self::{array::*, errors::*, graph::*, module::*, threading::*, workspace::*}; + #[cfg(target_env = "sgx")] use self::sgx::ocall_packed_func; diff --git a/rust/runtime/src/module.rs b/rust/runtime/src/module.rs index 8e6f7d665dd4..10f17f59015e 100644 --- a/rust/runtime/src/module.rs +++ b/rust/runtime/src/module.rs @@ -1,10 +1,9 @@ use std::{ collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex, }; - -use crate::{ - ffi::runtime::BackendPackedCFunc, - packed_func::{wrap_backend_packed_func, PackedFunc}, +use tvm_common::{ + ffi::BackendPackedCFunc, + packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue}, }; pub trait Module { @@ -34,6 +33,24 @@ impl Default for SystemLibModule { } } +// @see `WrapPackedFunc` in `llvm_module.cc`. +pub(super) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc { + box move |args: &[TVMArgValue]| { + func( + args.iter() + .map(|ref arg| arg.value) + .collect::>() + .as_ptr(), + args.iter() + .map(|ref arg| arg.type_code as i32) + .collect::>() + .as_ptr() as *const i32, + args.len() as i32, + ); + TVMRetValue::default() + } +} + #[no_mangle] pub extern "C" fn TVMBackendRegisterSystemLibSymbol( cname: *const c_char, diff --git a/rust/runtime/src/packed_func.rs b/rust/runtime/src/packed_func.rs deleted file mode 100644 index 43254f4dcad9..000000000000 --- a/rust/runtime/src/packed_func.rs +++ /dev/null @@ -1,116 +0,0 @@ -use std::{convert::TryFrom, marker::PhantomData, os::raw::c_void}; - -use tvm_common::{ - ffi::{ - BackendPackedCFunc, DLTensor as _DLTensor, TVMTypeCode_kArrayHandle, - TVMTypeCode_kNDArrayContainer, TVMValue as _TVMValue, - }, - TVMArgValue, TVMRetValue, TVMTypeCode, TVMValue, -}; - -use crate::{array::Tensor, errors::*, DLTensor}; - -pub type PackedFunc = Box TVMRetValue + Send + Sync>; - -/// Calls a packed function and returns a `TVMRetValue`. -/// -/// # Example -/// -/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)` -#[macro_export] -macro_rules! call_packed { - ($fn:expr, $($args:expr),+) => { - $fn(&[$($args.into(),)+]) - }; - ($fn:expr) => { - $fn(&Vec::new()) - }; -} - -impl<'a> From<&'a DLTensor> for TVMArgValue<'a> { - fn from(arr: &'a DLTensor) -> Self { - let raw = _TVMValue { - v_handle: arr as *const _ as *mut DLTensor as *mut c_void, - }; - TVMArgValue { - value: TVMValue::new(raw), - type_code: TVMTypeCode::kArrayHandle, - lifetime: PhantomData, - } - } -} - -impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> { - fn from(arr: &'a mut DLTensor) -> Self { - let raw = _TVMValue { - v_handle: arr as *mut _ as *mut c_void, - }; - TVMArgValue { - value: TVMValue::new(raw), - type_code: TVMTypeCode::kArrayHandle, - lifetime: PhantomData, - } - } -} - -impl<'a> TryFrom> for Tensor<'a> { - type Error = Error; - fn try_from(val: TVMArgValue<'a>) -> Result { - ensure!( - val.type_code == TVMTypeCode::kArrayHandle - || val.type_code == TVMTypeCode::kNDArrayContainer, - "Could not downcast arg. Expected `{}` or `{}`, but got `{}`", - TVMTypeCode::kArrayHandle, - TVMTypeCode::kNDArrayContainer, - val.type_code, - ); - - let dlt = unsafe { *(val.value.v_handle as *mut _DLTensor as *const _DLTensor) }; - Ok(DLTensor::new(dlt).into()) - } -} - -impl<'a, 't> From<&'t Tensor<'a>> for TVMRetValue { - fn from(val: &'t Tensor<'a>) -> Self { - TVMRetValue { - prim_value: 0, - box_value: box DLTensor::from(val), - type_code: TVMTypeCode::kNDArrayContainer, - } - } -} - -impl<'a> TryFrom for Tensor<'a> { - type Error = Error; - fn try_from(ret: TVMRetValue) -> Result { - ensure!( - ret.type_code == TVMTypeCode::kArrayHandle - || ret.type_code == TVMTypeCode::kNDArrayContainer, - "Could not downcast arg. Expected `{}` or `{}`, but got `{}`", - TVMTypeCode_kArrayHandle, - TVMTypeCode_kNDArrayContainer, - ret.type_code, - ); - - let dlt = unsafe { *(ret.prim_value as *mut _DLTensor as *const _DLTensor) }; - Ok(DLTensor::new(dlt).into()) - } -} - -// @see `WrapPackedFunc` in `llvm_module.cc`. -pub(crate) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc { - box move |args: &[TVMArgValue]| { - func( - args.iter() - .map(|ref arg| arg.value.inner) - .collect::>() - .as_ptr(), - args.iter() - .map(|ref arg| arg.type_code as i32) - .collect::>() - .as_ptr() as *const i32, - args.len() as i32, - ); - TVMRetValue::default() - } -} diff --git a/rust/runtime/src/sgx.rs b/rust/runtime/src/sgx.rs index 1edf3ef497e7..1c5cb1ae4a40 100644 --- a/rust/runtime/src/sgx.rs +++ b/rust/runtime/src/sgx.rs @@ -4,7 +4,7 @@ use std::{ }; use errors::Result; -use ffi::runtime::TVMValue; +use ffi::TVMValue; use runtime::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue}; pub use runtime::threading::tvm_run_worker as run_worker; diff --git a/rust/runtime/src/threading.rs b/rust/runtime/src/threading.rs index 38f4b7d23f0f..e6b0a6d11698 100644 --- a/rust/runtime/src/threading.rs +++ b/rust/runtime/src/threading.rs @@ -1,7 +1,7 @@ use std::{ os::raw::{c_int, c_void}, sync::{ - atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT}, + atomic::{AtomicUsize, Ordering}, Arc, Barrier, }, }; @@ -18,8 +18,9 @@ use std::{ use std::{collections::VecDeque, ptr, sync::Mutex}; use bounded_spsc_queue::{self, Producer}; +use tvm_common::ffi::TVMParallelGroupEnv; -use crate::{errors::*, ffi::runtime::TVMParallelGroupEnv}; +use crate::errors::*; #[cfg(target_env = "sgx")] use super::{sgx::ocall_packed_func, TVMArgValue, TVMRetValue}; @@ -251,7 +252,7 @@ pub extern "C" fn TVMBackendParallelLaunch( cb: cb, cdata: cdata, req_num_tasks: num_task, - pending: Arc::new(ATOMIC_USIZE_INIT), + pending: Arc::new(AtomicUsize::new(0)), }); }); } @@ -273,7 +274,7 @@ pub(crate) fn sgx_join_threads() { cb: poison_pill, cdata: ptr::null(), req_num_tasks: 0, - pending: Arc::new(ATOMIC_USIZE_INIT), + pending: Arc::new(AtomicUsize::new(0)), }); }); ocall_packed!("__sgx_thread_group_join__", 0); @@ -322,8 +323,8 @@ mod tests { #[test] fn test_parallel_launch() { TVMBackendParallelLaunch(flambda, ptr::null(), 6); - let counter = ATOMIC_USIZE_INIT; - let task_ids_sum = ATOMIC_USIZE_INIT; + let counter = AtomicUsize::new(0); + let task_ids_sum = AtomicUsize::new(0); let cdata = (counter, task_ids_sum); let num_tasks = 3; TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks); diff --git a/rust/runtime/tests/test_nnvm/Cargo.toml b/rust/runtime/tests/test_nnvm/Cargo.toml index 14d0b3961ad3..259af2341d95 100644 --- a/rust/runtime/tests/test_nnvm/Cargo.toml +++ b/rust/runtime/tests/test_nnvm/Cargo.toml @@ -5,7 +5,7 @@ license = "Apache-2.0" authors = ["TVM Contributors"] [dependencies] -ndarray = "0.11.2" +ndarray="0.12.1" serde = "1.0.59" serde_json = "1.0.17" tvm-runtime = { path = "../../" } diff --git a/rust/runtime/tests/test_nnvm/src/main.rs b/rust/runtime/tests/test_nnvm/src/main.rs index 50179798cd32..8329b9e86c28 100644 --- a/rust/runtime/tests/test_nnvm/src/main.rs +++ b/rust/runtime/tests/test_nnvm/src/main.rs @@ -5,6 +5,7 @@ extern crate ndarray; extern crate serde; extern crate serde_json; +#[macro_use] extern crate tvm_runtime; use std::{collections::HashMap, convert::TryFrom, fs, io::Read}; diff --git a/rust/runtime/tests/test_tvm_basic/Cargo.toml b/rust/runtime/tests/test_tvm_basic/Cargo.toml index 2a753b430c47..561215daf0a7 100644 --- a/rust/runtime/tests/test_tvm_basic/Cargo.toml +++ b/rust/runtime/tests/test_tvm_basic/Cargo.toml @@ -5,7 +5,7 @@ license = "Apache-2.0" authors = ["TVM Contributors"] [dependencies] -ndarray = "0.11.2" +ndarray="0.12.1" tvm-runtime = { path = "../../" } [build-dependencies] From c57bfdcbd04d57eddea7371040f2ef27e32575f0 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Mon, 18 Feb 2019 01:34:12 +0000 Subject: [PATCH 03/17] Merge types for frontend --- rust/common/build.rs | 3 +- rust/common/src/errors.rs | 15 +- rust/common/src/lib.rs | 2 +- rust/common/src/packed_func.rs | 238 ++++++++---------- rust/common/src/value.rs | 95 +++++-- rust/frontend/examples/resnet/src/main.rs | 26 +- rust/frontend/src/bytearray.rs | 16 +- rust/frontend/src/context.rs | 17 +- rust/frontend/src/function.rs | 14 +- rust/frontend/src/lib.rs | 1 - rust/frontend/src/ndarray.rs | 32 ++- rust/frontend/src/value.rs | 53 ++-- rust/frontend/tests/basics/src/main.rs | 11 +- rust/frontend/tests/callback/src/bin/array.rs | 18 +- rust/frontend/tests/callback/src/bin/error.rs | 2 +- rust/frontend/tests/callback/src/bin/float.rs | 6 +- rust/frontend/tests/callback/src/bin/int.rs | 4 +- .../frontend/tests/callback/src/bin/string.rs | 18 +- 18 files changed, 316 insertions(+), 255 deletions(-) diff --git a/rust/common/build.rs b/rust/common/build.rs index 90e7a11dd12a..f07e71f0f2bb 100644 --- a/rust/common/build.rs +++ b/rust/common/build.rs @@ -9,6 +9,7 @@ fn main() { println!("cargo:rustc-link-search={}/build", env!("TVM_HOME")); } + // @see rust-bindgen#550 for `blacklist_type` bindgen::Builder::default() .header(format!( "{}/include/tvm/runtime/c_runtime_api.h", @@ -19,7 +20,7 @@ fn main() { env!("TVM_HOME") )) .clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME"))) - .blacklist_type("max_align_t") // @see rust-bindgen#550 + .blacklist_type("max_align_t") .layout_tests(false) .derive_partialeq(true) .derive_eq(true) diff --git a/rust/common/src/errors.rs b/rust/common/src/errors.rs index 4af4e3e38720..0f9e9dfee304 100644 --- a/rust/common/src/errors.rs +++ b/rust/common/src/errors.rs @@ -23,20 +23,15 @@ fn type_code_to_string(type_code: &i64) -> String { error_chain! { errors { - // TryFromTVMArgValueError(expected: String, actual: String) { - // description("mismatched types while converting from TVMArgValue") - // display("expected `{}` but given `{}`", expected, actual) - // } - // - // TryFromTVMRetValueError(expected: String, actual: String) { - // description("mismatched types while downcasting TVMRetValue") - // display("invalid downcast: expected `{}` but given `{}`", expected, actual) - // } - TryFromTVMRetValueError(expected_type: String, actual_type_code: i64) { description("mismatched types while downcasting TVMRetValue") display("invalid downcast: expected `{}` but was `{}`", expected_type, type_code_to_string(actual_type_code)) } } + foreign_links { + IntoString(std::ffi::IntoStringError); + ParseInt(std::num::ParseIntError); + Utf8(std::str::Utf8Error); + } } diff --git a/rust/common/src/lib.rs b/rust/common/src/lib.rs index f3d3ac915ee6..a1b243ecf6ab 100644 --- a/rust/common/src/lib.rs +++ b/rust/common/src/lib.rs @@ -28,5 +28,5 @@ pub mod packed_func; pub mod value; pub use errors::*; -pub use ffi::TVMType; +pub use ffi::{TVMContext, TVMType}; pub use packed_func::{TVMArgValue, TVMRetValue}; diff --git a/rust/common/src/packed_func.rs b/rust/common/src/packed_func.rs index 8d02650debdd..ddf0df6eb4ca 100644 --- a/rust/common/src/packed_func.rs +++ b/rust/common/src/packed_func.rs @@ -1,4 +1,4 @@ -use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void}; +use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void, str::FromStr}; pub use crate::ffi::TVMValue; use crate::{errors::*, ffi::*, value::*}; @@ -53,56 +53,55 @@ macro_rules! ensure_type { /// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode. macro_rules! impl_prim_tvm_arg { - ($type:ty, $field:ident, $code:expr, $as:ty) => { - impl<'a> From<$type> for TVMArgValue<'a> { - fn from(val: $type) -> Self { - TVMArgValue { - value: TVMValue { $field: val as $as }, - type_code: $code as i64, - _lifetime: PhantomData, + (@, $field:ident, $code:expr, $as:ty, $( $type:ty ),+) => { + $( + impl<'a> From<$type> for TVMArgValue<'a> { + fn from(val: $type) -> Self { + TVMArgValue { + value: TVMValue { $field: val as $as }, + type_code: $code as i64, + _lifetime: PhantomData, + } } } - } - impl<'a> From<&'a $type> for TVMArgValue<'a> { - fn from(val: &'a $type) -> Self { - TVMArgValue { - value: TVMValue { - $field: val.to_owned() as $as, - }, - type_code: $code as i64, - _lifetime: PhantomData, + impl<'a> From<&'a $type> for TVMArgValue<'a> { + fn from(val: &'a $type) -> Self { + TVMArgValue { + value: TVMValue { + $field: val.to_owned() as $as, + }, + type_code: $code as i64, + _lifetime: PhantomData, + } } } - } - impl<'a> TryFrom> for $type { - type Error = Error; - fn try_from(val: TVMArgValue<'a>) -> Result { - ensure_type!(val, $code); - Ok(unsafe { val.value.$field as $type }) + impl<'a> TryFrom> for $type { + type Error = Error; + fn try_from(val: TVMArgValue<'a>) -> Result { + ensure_type!(val, $code); + Ok(unsafe { val.value.$field as $type }) + } } - } - }; - ($type:ty, $field:ident, $code:expr) => { - impl_prim_tvm_arg!($type, $field, $code, $type); + + impl<'a> TryFrom<&TVMArgValue<'a>> for $type { + type Error = Error; + fn try_from(val: &TVMArgValue<'a>) -> Result { + ensure_type!(val, $code); + Ok(unsafe { val.value.$field as $type }) + } + } + )+ }; - ($type:ty,v_int64) => { - impl_prim_tvm_arg!($type, v_int64, DLDataTypeCode_kDLInt, i64); + (v_int64, $( $type:ty ),+) => { + impl_prim_tvm_arg!(@, v_int64, DLDataTypeCode_kDLInt, i64, $( $type ),+); }; - ($type:ty,v_float64) => { - impl_prim_tvm_arg!($type, v_float64, DLDataTypeCode_kDLFloat, f64); + (v_float64, $( $type:ty ),+) => { + impl_prim_tvm_arg!(@, v_float64, DLDataTypeCode_kDLFloat, f64, $( $type ),+); }; } -impl_prim_tvm_arg!(f32, v_float64); -impl_prim_tvm_arg!(f64, v_float64); -impl_prim_tvm_arg!(i8, v_int64); -impl_prim_tvm_arg!(u8, v_int64); -impl_prim_tvm_arg!(i32, v_int64); -impl_prim_tvm_arg!(u32, v_int64); -impl_prim_tvm_arg!(i64, v_int64); -impl_prim_tvm_arg!(u64, v_int64); -impl_prim_tvm_arg!(isize, v_int64); -impl_prim_tvm_arg!(usize, v_int64); +impl_prim_tvm_arg!(v_float64, f32, f64); +impl_prim_tvm_arg!(v_int64, i8, u8, i32, u32, i64, u64, isize, usize); impl<'a> From<&std::ffi::CString> for TVMArgValue<'a> { fn from(string: &std::ffi::CString) -> Self { @@ -116,6 +115,22 @@ impl<'a> From<&std::ffi::CString> for TVMArgValue<'a> { } } +impl<'a> TryFrom> for &str { + type Error = Error; + fn try_from(arg: TVMArgValue<'a>) -> Result { + ensure_type!(arg, TVMTypeCode_kStr); + Ok(unsafe { std::ffi::CStr::from_ptr(arg.value.v_handle as *const i8) }.to_str()?) + } +} + +impl<'a> TryFrom<&TVMArgValue<'a>> for &str { + type Error = Error; + fn try_from(arg: &TVMArgValue<'a>) -> Result { + ensure_type!(arg, TVMTypeCode_kStr); + Ok(unsafe { std::ffi::CStr::from_ptr(arg.value.v_handle as *const i8) }.to_str()?) + } +} + /// Creates a conversion to a `TVMArgValue` for an object handle. impl<'a, T> From<*const T> for TVMArgValue<'a> { fn from(ptr: *const T) -> Self { @@ -188,125 +203,82 @@ impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for TVMType { /// assert_eq!(String::try_from(t).unwrap(), s); /// ``` pub struct TVMRetValue { - /// A primitive return value, if any. - pub prim_value: u64, - /// An object return value, if any. + pub value: TVMValue, pub box_value: Box, - /// The DLDataTypeCode which determines whether `prim_value` or `box_value` is in use. pub type_code: i64, } impl TVMRetValue { pub fn from_tvm_value(value: TVMValue, type_code: i64) -> Self { - unsafe { - Self { - prim_value: match type_code { - 0 | 1 => value.v_int64 as u64, - 2 => value.v_float64 as u64, - 3 | 7 | 8 | 9 | 10 => value.v_handle as u64, - 11 | 12 => value.v_str as u64, - _ => 0, - } as u64, - box_value: box (), - type_code: type_code, - } + Self { + value, + type_code, + box_value: box (), } } pub fn into_tvm_value(self) -> (TVMValue, TVMTypeCode) { - let val = match self.type_code { - 0 | 1 => TVMValue { - v_int64: self.prim_value.clone() as i64, - }, - 2 => TVMValue { - v_float64: self.prim_value.clone() as f64, - }, - 3 | 7 | 8 | 9 | 10 | 13 => TVMValue { - v_handle: Box::into_raw(self.box_value) as *mut c_void, - }, - 11 | 12 => TVMValue { - v_str: Box::into_raw(self.box_value) as *const _, - }, - _ => unreachable!(), - }; - (val, self.type_code as TVMTypeCode) + (self.value, self.type_code as TVMTypeCode) } } impl Default for TVMRetValue { fn default() -> Self { TVMRetValue { - prim_value: 0, - box_value: box (), + value: TVMValue { v_int64: 0 as i64 }, type_code: 0, + box_value: box (), } } } -macro_rules! impl_prim_ret_value { - ($type:ty, $code:expr) => { - impl From<$type> for TVMRetValue { - fn from(val: $type) -> Self { - TVMRetValue { - prim_value: val as u64, - box_value: box (), - type_code: $code, +macro_rules! impl_pod_ret_value { + ($code:expr, $( $ty:ty ),+ ) => { + $( + impl From<$ty> for TVMRetValue { + fn from(val: $ty) -> Self { + Self { + value: val.into(), + type_code: $code as i64, + box_value: box (), + } } } - } - impl TryFrom for $type { - type Error = Error; - fn try_from(ret: TVMRetValue) -> Result<$type> { - if ret.type_code == $code { - Ok(ret.prim_value as $type) - } else { - bail!(ErrorKind::TryFromTVMRetValueError( - stringify!($type).to_string(), - ret.type_code - )) + + impl TryFrom for $ty { + type Error = Error; + fn try_from(ret: TVMRetValue) -> Result<$ty> { + ensure_type!(ret, $code); + Ok(ret.value.into()) } } - } + )+ }; } -macro_rules! impl_boxed_ret_value { - ($type:ty, $code:expr) => { - impl From<$type> for TVMRetValue { - fn from(val: $type) -> Self { - TVMRetValue { - prim_value: 0, - box_value: box val, - type_code: $code, - } - } - } - impl TryFrom for $type { - type Error = Error; - fn try_from(ret: TVMRetValue) -> Result<$type> { - if let Ok(val) = ret.box_value.downcast::<$type>() { - Ok(*val) - } else { - bail!(ErrorKind::TryFromTVMRetValueError( - stringify!($type).to_string(), - ret.type_code - )) - } - } - } - }; +impl_pod_ret_value!(DLDataTypeCode_kDLInt, i8, i16, i32, i64, isize); +impl_pod_ret_value!(DLDataTypeCode_kDLUInt, u8, u16, u32, u64, usize); +impl_pod_ret_value!(DLDataTypeCode_kDLFloat, f32, f64); +impl_pod_ret_value!(TVMTypeCode_kTVMType, TVMType); +impl_pod_ret_value!(TVMTypeCode_kTVMContext, TVMContext); + +impl TryFrom for String { + type Error = Error; + fn try_from(ret: TVMRetValue) -> Result { + ensure_type!(ret, TVMTypeCode_kStr); + Ok(unsafe { std::ffi::CString::from_raw(ret.value.v_str as *mut i8) }.into_string()?) + } } -impl_prim_ret_value!(i8, 0); -impl_prim_ret_value!(u8, 1); -impl_prim_ret_value!(i16, 0); -impl_prim_ret_value!(u16, 1); -impl_prim_ret_value!(i32, 0); -impl_prim_ret_value!(u32, 1); -impl_prim_ret_value!(f32, 2); -impl_prim_ret_value!(i64, 0); -impl_prim_ret_value!(u64, 1); -impl_prim_ret_value!(f64, 2); -impl_prim_ret_value!(isize, 0); -impl_prim_ret_value!(usize, 1); -impl_boxed_ret_value!(String, 11); +impl From for TVMRetValue { + fn from(s: String) -> Self { + let s_box = box std::ffi::CString::new(s).unwrap(); + Self { + value: TVMValue { + v_handle: s_box.as_ptr() as *mut i8 as *mut c_void, + }, + box_value: s_box, + type_code: TVMTypeCode_kStr as i64, + } + } +} diff --git a/rust/common/src/value.rs b/rust/common/src/value.rs index 7fc82c6af665..83ad8cff8f38 100644 --- a/rust/common/src/value.rs +++ b/rust/common/src/value.rs @@ -1,3 +1,5 @@ +use std::str::FromStr; + use crate::ffi::*; impl TVMType { @@ -12,26 +14,23 @@ impl TVMType { /// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}` /// such as "int32", "float32" or with lane "float32x1". -impl<'a> From<&'a str> for TVMType { - fn from(type_str: &'a str) -> Self { +impl FromStr for TVMType { + type Err = crate::errors::Error; + fn from_str(type_str: &str) -> Result { if type_str == "bool" { - return TVMType::new(1, 1, 1); + return Ok(TVMType::new(1, 1, 1)); } let mut type_lanes = type_str.split("x"); let typ = type_lanes.next().expect("Missing dtype"); let lanes = type_lanes .next() - .map(|l| u16::from_str_radix(l, 10).expect(&format!("Bad dtype lanes: {}", l))) - .unwrap_or(1); + .map(|l| ::from_str_radix(l, 10)) + .unwrap_or(Ok(1))?; let (type_name, bits) = match typ.find(char::is_numeric) { Some(idx) => { let (name, bits_str) = typ.split_at(idx); - ( - name, - u8::from_str_radix(bits_str, 10) - .expect(&format!("Bad dtype bits: {}", bits_str)), - ) + (name, u8::from_str_radix(bits_str, 10)?) } None => (typ, 32), }; @@ -41,10 +40,10 @@ impl<'a> From<&'a str> for TVMType { "uint" => 1, "float" => 2, "handle" => 3, - _ => unimplemented!(), + _ => return Err(format!("Unknown type {}", type_name).into()), }; - TVMType::new(type_code, bits, lanes) + Ok(TVMType::new(type_code, bits, lanes)) } } @@ -70,23 +69,69 @@ impl std::fmt::Display for TVMType { } } -macro_rules! impl_tvm_val_from_pod { - ($field:ident, $ty:ty) => { - impl From<$ty> for TVMValue { - fn from(val: $ty) -> Self { - TVMValue { $field: val } +macro_rules! impl_pod_tvm_value { + ($field:ident, $field_ty:ty, $( $ty:ty ),+) => { + $( + impl From<$ty> for TVMValue { + fn from(val: $ty) -> Self { + TVMValue { $field: val as $field_ty } + } } - } + + impl From for $ty { + fn from(val: TVMValue) -> Self { + unsafe { val.$field as $ty } + } + } + )+ }; + ($field:ident, $ty:ty) => { + impl_pod_tvm_value!($field, $ty, $ty); + } } -impl_tvm_val_from_pod!(v_type, TVMType); -impl_tvm_val_from_pod!(v_ctx, TVMContext); +impl_pod_tvm_value!(v_int64, i64, i8, u8, i16, u16, i32, u32, i64, u64, isize, usize); +impl_pod_tvm_value!(v_float64, f64, f32, f64); +impl_pod_tvm_value!(v_type, TVMType); +impl_pod_tvm_value!(v_ctx, TVMContext); -impl From for TVMValue { - fn from(dev: DLDeviceType) -> Self { - TVMValue { - v_int64: dev as i64, +macro_rules! impl_tvm_context { + ( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => { + /// Creates a TVMContext from a string (e.g., "cpu", "gpu", "ext_dev") + impl FromStr for TVMContext { + type Err = crate::errors::Error; + fn from_str(type_str: &str) -> Result { + Ok(Self { + device_type: match type_str { + $( $( stringify!($dev_name) )|+ => $dev_type ),+, + _ => return Err(format!("device {} not supported", type_str).into()), + }, + device_id: 0, + }) + } } - } + + impl TVMContext { + $( + $( + pub fn $dev_name(device_id: usize) -> Self { + Self { + device_type: $dev_type, + device_id: device_id as i32, + } + } + )+ + )+ + } + }; } + +impl_tvm_context!( + DLDeviceType_kDLCPU: [cpu, llvm, stackvm], + DLDeviceType_kDLGPU: [gpu, cuda, nvptx], + DLDeviceType_kDLOpenCL: [cl], + DLDeviceType_kDLMetal: [metal], + DLDeviceType_kDLVPI: [vpi], + DLDeviceType_kDLROCM: [rocm], + DLDeviceType_kDLExtDev: [ext_dev] +); diff --git a/rust/frontend/examples/resnet/src/main.rs b/rust/frontend/examples/resnet/src/main.rs index 869a35b3a3a4..060fad3c164a 100644 --- a/rust/frontend/examples/resnet/src/main.rs +++ b/rust/frontend/examples/resnet/src/main.rs @@ -10,6 +10,7 @@ use std::{ convert::TryInto, fs::{self, File}, path::Path, + str::FromStr, }; use image::{FilterType, GenericImageView}; @@ -44,14 +45,20 @@ fn main() { // make arr shape as [1, 3, 224, 224] acceptable to resnet let arr = arr.insert_axis(Axis(0)); // create input tensor from rust's ndarray - let input = - NDArray::from_rust_ndarray(&arr, TVMContext::cpu(0), TVMType::from("float32")).unwrap(); + let input = NDArray::from_rust_ndarray( + &arr, + TVMContext::cpu(0), + TVMType::from_str("float32").unwrap(), + ) + .unwrap(); println!( "input size is {:?}", input.shape().expect("cannot get the input shape") ); - let graph = - fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_graph.json")).unwrap(); + let graph = std::ffi::CString::new( + fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_graph.json")).unwrap(), + ) + .unwrap(); // load the built module let lib = Module::load(&Path::new(concat!( env!("CARGO_MANIFEST_DIR"), @@ -59,7 +66,7 @@ fn main() { ))) .unwrap(); // get the global TVM graph runtime function - let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap(); + let runtime_create_fn = Function::get("tvm.graph_runtime.create").unwrap(); let runtime_create_fn_ret = call_packed!( runtime_create_fn, &graph, @@ -85,14 +92,19 @@ fn main() { .get_function("set_input", false) .unwrap(); - call_packed!(set_input_fn, "data", &input).unwrap(); + let data_str = std::ffi::CString::new("data").unwrap(); + call_packed!(set_input_fn, &data_str, &input).unwrap(); // get `run` function from runtime module let ref run_fn = graph_runtime_module.get_function("run", false).unwrap(); // execute the run function. Note that it has no argument call_packed!(run_fn,).unwrap(); // prepare to get the output let output_shape = &mut [1, 1000]; - let output = NDArray::empty(output_shape, TVMContext::cpu(0), TVMType::from("float32")); + let output = NDArray::empty( + output_shape, + TVMContext::cpu(0), + TVMType::from_str("float32").unwrap(), + ); // get the `get_output` function from runtime module let ref get_output_fn = graph_runtime_module .get_function("get_output", false) diff --git a/rust/frontend/src/bytearray.rs b/rust/frontend/src/bytearray.rs index cfcfa06da25c..9274dba862da 100644 --- a/rust/frontend/src/bytearray.rs +++ b/rust/frontend/src/bytearray.rs @@ -3,9 +3,9 @@ //! //! For more detail, please see the example `resnet` in `examples` repository. -use std::os::raw::c_char; +use std::os::raw::{c_char, c_void}; -use tvm_common::ffi; +use tvm_common::{ffi, TVMArgValue}; /// A struct holding TVM byte-array. /// @@ -54,6 +54,18 @@ impl<'a> From<&'a Vec> for TVMByteArray { } } +impl<'a> From<&TVMByteArray> for TVMArgValue<'a> { + fn from(arr: &TVMByteArray) -> Self { + Self { + value: ffi::TVMValue { + v_handle: &arr.inner as *const ffi::TVMByteArray as *const c_void as *mut c_void, + }, + type_code: ffi::TVMTypeCode_kBytes as i64, + _lifetime: std::marker::PhantomData, + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/rust/frontend/src/context.rs b/rust/frontend/src/context.rs index 3683e2ec3272..3972653ab163 100644 --- a/rust/frontend/src/context.rs +++ b/rust/frontend/src/context.rs @@ -24,7 +24,10 @@ use std::{ ptr, }; -use tvm_common::ffi; +use tvm_common::{ + ffi::{self, TVMTypeCode_kTVMType, TVMValue}, + TVMArgValue, +}; use crate::{function, Result}; @@ -120,6 +123,18 @@ impl<'a> From<&'a str> for TVMDeviceType { } } +impl<'a> From<&'a TVMDeviceType> for TVMArgValue<'a> { + fn from(dev_type: &'a TVMDeviceType) -> Self { + Self { + value: TVMValue { + v_int64: dev_type.0 as i64, + }, + type_code: ffi::DLDataTypeCode_kDLInt as i64, + _lifetime: std::marker::PhantomData, + } + } +} + /// Represents the underlying device context. Default is cpu. /// /// ## Examples diff --git a/rust/frontend/src/function.rs b/rust/frontend/src/function.rs index d6cb4d227165..4792df82a2ee 100644 --- a/rust/frontend/src/function.rs +++ b/rust/frontend/src/function.rs @@ -175,7 +175,7 @@ impl<'a, 'm> Builder<'a, 'm> { /// Sets an output for a function that requirs a mutable output to be provided. /// See the `basics` in tests for an example. - pub fn set_output(&mut self, ret: T) -> Result<&mut Self> + pub fn set_output(&mut self, mut ret: T) -> Result<&mut Self> where TVMRetValue: From, { @@ -430,17 +430,17 @@ mod tests { #[test] fn get_fn() { - assert!(Function::get(CANARY, true).is_some()); - assert!(Function::get("does not exists!", false).is_none()); + assert!(Function::get(CANARY).is_some()); + assert!(Function::get("does not exists!").is_none()); } #[test] fn provide_args() { + let str_arg = CString::new("test").unwrap(); let mut func = Builder::default(); - func.get_function("tvm.graph_runtime.remote_create", true) + func.get_function("tvm.graph_runtime.remote_create") .args(&[10, 20]) - .arg(&"test".to_owned()); - assert!(func.arg_buf.is_some()); - assert_eq!(func.arg_buf.take().map(|bv| Vec::from(bv).len()), Some(3)); + .arg(&str_arg); + assert_eq!(func.arg_buf.len(), 3); } } diff --git a/rust/frontend/src/lib.rs b/rust/frontend/src/lib.rs index 3789e1d2607e..0579e0de916f 100644 --- a/rust/frontend/src/lib.rs +++ b/rust/frontend/src/lib.rs @@ -11,7 +11,6 @@ //! //! Checkout the `examples` repository for more details. -#![crate_name = "tvm_frontend"] #![allow(non_camel_case_types, unused_unsafe)] #![feature(try_from, try_trait, fn_traits, unboxed_closures, box_syntax)] diff --git a/rust/frontend/src/ndarray.rs b/rust/frontend/src/ndarray.rs index 0992a597c5e7..dca62d1d7a07 100644 --- a/rust/frontend/src/ndarray.rs +++ b/rust/frontend/src/ndarray.rs @@ -23,7 +23,7 @@ //! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer //! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx -use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice}; +use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; use crate::rust_ndarray::{Array, ArrayD}; use num_traits::Num; @@ -249,7 +249,7 @@ macro_rules! impl_from_ndarray_rustndarray { if nd.shape().is_none() { bail!("{}", ErrorKind::EmptyArray); } - assert_eq!(nd.dtype(), TVMType::from($type_name), "Type mismatch"); + assert_eq!(nd.dtype(), TVMType::from_str($type_name)?, "Type mismatch"); Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?) } } @@ -260,7 +260,7 @@ macro_rules! impl_from_ndarray_rustndarray { if nd.shape().is_none() { bail!("{}", ErrorKind::EmptyArray); } - assert_eq!(nd.dtype(), TVMType::from($type_name), "Type mismatch"); + assert_eq!(nd.dtype(), TVMType::from_str($type_name)?, "Type mismatch"); Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?) } } @@ -308,7 +308,7 @@ mod tests { fn basics() { let shape = &mut [1, 2, 3]; let ctx = TVMContext::cpu(0); - let ndarray = NDArray::empty(shape, ctx, TVMType::from("int32")); + let ndarray = NDArray::empty(shape, ctx, TVMType::from_str("int32").unwrap()); assert_eq!(ndarray.shape().unwrap(), shape); assert_eq!( ndarray.size().unwrap(), @@ -324,7 +324,7 @@ mod tests { let shape = &mut [4]; let mut data = vec![1i32, 2, 3, 4]; let ctx = TVMContext::cpu(0); - let mut ndarray = NDArray::empty(shape, ctx, TVMType::from("int32")); + let mut ndarray = NDArray::empty(shape, ctx, TVMType::from_str("int32").unwrap()); assert!(ndarray.to_vec::().is_ok()); ndarray.copy_from_buffer(&mut data); assert_eq!(ndarray.shape().unwrap(), shape); @@ -333,7 +333,11 @@ mod tests { assert!(ndarray.is_contiguous().is_ok()); assert_eq!(ndarray.byte_offset(), 0); let mut shape = vec![4]; - let e = NDArray::empty(&mut shape, TVMContext::cpu(0), TVMType::from("int32")); + let e = NDArray::empty( + &mut shape, + TVMContext::cpu(0), + TVMType::from_str("int32").unwrap(), + ); let nd = ndarray.copy_to_ndarray(e); assert!(nd.is_ok()); assert_eq!(nd.unwrap().to_vec::().unwrap(), data); @@ -345,9 +349,13 @@ mod tests { let mut shape = vec![4]; let mut data = vec![1f32, 2., 3., 4.]; let ctx = TVMContext::cpu(0); - let mut nd_float = NDArray::empty(&mut shape, ctx.clone(), TVMType::from("float32")); + let mut nd_float = NDArray::empty( + &mut shape, + ctx.clone(), + TVMType::from_str("float32").unwrap(), + ); nd_float.copy_from_buffer(&mut data); - let empty_int = NDArray::empty(&mut shape, ctx, TVMType::from("int32")); + let empty_int = NDArray::empty(&mut shape, ctx, TVMType::from_str("int32").unwrap()); nd_float.copy_to_ndarray(empty_int).unwrap(); } @@ -356,8 +364,12 @@ mod tests { let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) .unwrap() .into_dyn(); - let nd = - NDArray::from_rust_ndarray(&a, TVMContext::cpu(0), TVMType::from("float32")).unwrap(); + let nd = NDArray::from_rust_ndarray( + &a, + TVMContext::cpu(0), + TVMType::from_str("float32").unwrap(), + ) + .unwrap(); assert_eq!(nd.shape().unwrap(), &mut [2, 2]); let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); assert!(rnd.all_close(&a, 1e-8f32)); diff --git a/rust/frontend/src/value.rs b/rust/frontend/src/value.rs index e3b727cdf6e4..b579299d75ae 100644 --- a/rust/frontend/src/value.rs +++ b/rust/frontend/src/value.rs @@ -10,7 +10,8 @@ use tvm_common::{ }; use crate::{ - common_errors::*, Function, Module, NDArray, TVMArgValue, TVMByteArray, TVMContext, TVMRetValue, + common_errors::*, context::TVMContext, Function, Module, NDArray, TVMArgValue, TVMByteArray, + TVMRetValue, }; macro_rules! impl_tvm_val_from_handle { @@ -50,7 +51,9 @@ macro_rules! impl_tvm_val_from_handle { impl From<$ty> for TVMRetValue { fn from(val: $ty) -> TVMRetValue { TVMRetValue { - prim_value: 0, + value: TVMValue { + v_handle: val.handle() as *mut c_void, + }, box_value: box val, type_code: $type_code as i64, } @@ -60,14 +63,8 @@ macro_rules! impl_tvm_val_from_handle { impl TryFrom for $ty { type Error = Error; fn try_from(ret: TVMRetValue) -> Result<$ty> { - if let Ok(handle) = ret.box_value.downcast::<$handle>() { - Ok($ty::new(*handle)) - } else { - bail!(ErrorKind::TryFromTVMRetValueError( - stringify!($type_code).to_string(), - ret.type_code, - )) - } + ensure_type!(ret, $type_code); + Ok($ty::new(unsafe { ret.value.v_handle as $handle })) } } }; @@ -89,30 +86,12 @@ impl<'a> From<&'a TVMByteArray> for TVMValue { } } -impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for TVMByteArray { - type Error = Error; - fn try_from(arg: &TVMArgValue<'v>) -> Result { - ensure_type!(arg, ffi::TVMTypeCode_kBytes); - Ok(TVMByteArray::new(unsafe { - *(arg.value.v_handle as *mut ffi::TVMByteArray) - })) - } -} - -impl<'a, 'b> TryFrom<&'b TVMArgValue<'a>> for TVMContext { - type Error = Error; - fn try_from(arg: &TVMArgValue<'a>) -> Result { - ensure_type!(arg, ffi::TVMTypeCode_kTVMContext); - Ok(unsafe { arg.value.v_ctx.into() }) - } -} - macro_rules! impl_boxed_ret_value { ($type:ty, $code:expr) => { impl From<$type> for TVMRetValue { fn from(val: $type) -> Self { TVMRetValue { - prim_value: 0, + value: TVMValue { v_int64: 0 }, box_value: box val, type_code: $code as i64, } @@ -134,14 +113,24 @@ macro_rules! impl_boxed_ret_value { }; } -// impl_boxed_ret_value!(TVMType, ffi::TVMTypeCode_kTVMType); impl_boxed_ret_value!(TVMContext, ffi::TVMTypeCode_kTVMContext); impl_boxed_ret_value!(TVMByteArray, ffi::TVMTypeCode_kBytes); +impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for TVMByteArray { + type Error = Error; + fn try_from(arg: &TVMArgValue<'v>) -> Result { + ensure_type!(arg, ffi::TVMTypeCode_kBytes); + Ok(TVMByteArray::new(unsafe { + *(arg.value.v_handle as *mut ffi::TVMByteArray) + })) + } +} + #[cfg(test)] mod tests { use super::*; - use std::convert::TryInto; + use std::{convert::TryInto, str::FromStr}; + use tvm_common::ffi::TVMType; #[test] fn bytearray() { @@ -153,7 +142,7 @@ mod tests { #[test] fn ty() { - let t = TVMType::from("int32"); + let t = TVMType::from_str("int32").unwrap(); let tvm: TVMType = TVMRetValue::from(t).try_into().unwrap(); assert_eq!(tvm, t); } diff --git a/rust/frontend/tests/basics/src/main.rs b/rust/frontend/tests/basics/src/main.rs index 9430829b120e..eedb521dad43 100644 --- a/rust/frontend/tests/basics/src/main.rs +++ b/rust/frontend/tests/basics/src/main.rs @@ -3,7 +3,7 @@ extern crate ndarray as rust_ndarray; extern crate tvm_frontend as tvm; -use std::convert::TryInto; +use std::str::FromStr; use tvm::*; @@ -16,7 +16,7 @@ fn main() { } else { (TVMContext::gpu(0), "gpu") }; - let dtype = TVMType::from("float32"); + let dtype = TVMType::from_str("float32").unwrap(); let mut arr = NDArray::empty(shape, ctx, dtype); arr.copy_from_buffer(data.as_mut_slice()); let mut ret = NDArray::empty(shape, ctx, dtype); @@ -27,14 +27,11 @@ fn main() { if cfg!(feature = "gpu") { fadd.import_module(Module::load(&concat!(env!("OUT_DIR"), "/test_add.ptx")).unwrap()); } - let ret: NDArray = function::Builder::from(&mut fadd) + function::Builder::from(&mut fadd) .arg(&arr) .arg(&arr) - .set_output(ret) - .unwrap() + .arg(&mut ret) .invoke() - .unwrap() - .try_into() .unwrap(); assert_eq!(ret.to_vec::().unwrap(), vec![6f32, 8.0]); diff --git a/rust/frontend/tests/callback/src/bin/array.rs b/rust/frontend/tests/callback/src/bin/array.rs index 81dcadc30851..52fe27417d70 100644 --- a/rust/frontend/tests/callback/src/bin/array.rs +++ b/rust/frontend/tests/callback/src/bin/array.rs @@ -6,7 +6,10 @@ extern crate ndarray as rust_ndarray; extern crate tvm_frontend as tvm; use rust_ndarray::ArrayD; -use std::convert::{TryFrom, TryInto}; +use std::{ + convert::{TryFrom, TryInto}, + str::FromStr, +}; use tvm::*; @@ -16,7 +19,10 @@ fn main() { let mut ret = 0f32; let shape = &mut [2]; for arg in args.iter() { - let e = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32")); + let e = NDArray::empty( + shape, TVMContext::cpu(0), + TVMType::from_str("float32").unwrap() + ); let arg: NDArray = arg.try_into()?; let arr = arg.copy_to_ndarray(e)?; let rnd: ArrayD = ArrayD::try_from(&arr)?; @@ -28,12 +34,16 @@ fn main() { let shape = &mut [2]; let mut data = vec![3f32, 4.0]; - let mut arr = NDArray::empty(shape, TVMContext::cpu(0), TVMType::from("float32")); + let mut arr = NDArray::empty( + shape, + TVMContext::cpu(0), + TVMType::from_str("float32").unwrap(), + ); arr.copy_from_buffer(data.as_mut_slice()); let mut registered = function::Builder::default(); let ret: f32 = registered - .get_function("sum", true) + .get_function("sum") .arg(&arr) .arg(&arr) .invoke() diff --git a/rust/frontend/tests/callback/src/bin/error.rs b/rust/frontend/tests/callback/src/bin/error.rs index f40f0f157815..1a6e7efd74da 100644 --- a/rust/frontend/tests/callback/src/bin/error.rs +++ b/rust/frontend/tests/callback/src/bin/error.rs @@ -19,7 +19,7 @@ fn main() { } let mut registered = function::Builder::default(); - registered.get_function("error", true); + registered.get_function("error"); assert!(registered.func.is_some()); registered.args(&[10, 20]); diff --git a/rust/frontend/tests/callback/src/bin/float.rs b/rust/frontend/tests/callback/src/bin/float.rs index 3070552843d7..e5aa23b76378 100644 --- a/rust/frontend/tests/callback/src/bin/float.rs +++ b/rust/frontend/tests/callback/src/bin/float.rs @@ -11,16 +11,16 @@ fn main() { register_global_func! { fn sum(args: &[TVMArgValue]) -> Result { let mut ret = 0.0; - for arg in args.iter() { + for arg in args.into_iter() { let val: f64 = arg.try_into()?; ret += val; } - Ok(TVMRetValue::from(&ret)) + Ok(TVMRetValue::from(ret)) } } let mut registered = function::Builder::default(); - registered.get_function("sum", true); + registered.get_function("sum"); assert!(registered.func.is_some()); let ret: f64 = registered .args(&[10.0f64, 20.0, 30.0]) diff --git a/rust/frontend/tests/callback/src/bin/int.rs b/rust/frontend/tests/callback/src/bin/int.rs index 30188222054a..724f6082d8be 100644 --- a/rust/frontend/tests/callback/src/bin/int.rs +++ b/rust/frontend/tests/callback/src/bin/int.rs @@ -13,13 +13,13 @@ fn main() { let val: i64 = arg.try_into()?; ret += val; } - Ok(TVMRetValue::from(&ret)) + Ok(TVMRetValue::from(ret)) } tvm::function::register(sum, "mysum".to_owned(), false).unwrap(); let mut registered = function::Builder::default(); - registered.get_function("mysum", true); + registered.get_function("mysum"); assert!(registered.func.is_some()); let ret: i64 = registered .args(&[10, 20, 30]) diff --git a/rust/frontend/tests/callback/src/bin/string.rs b/rust/frontend/tests/callback/src/bin/string.rs index eafee31796bd..346019b1cab5 100644 --- a/rust/frontend/tests/callback/src/bin/string.rs +++ b/rust/frontend/tests/callback/src/bin/string.rs @@ -1,4 +1,4 @@ -#![feature(extern_crate_item_prelude, try_from)] +#![feature(try_from)] #![allow(unused_imports)] #[macro_use] @@ -12,20 +12,22 @@ fn main() { fn concate_str(args: &[TVMArgValue]) -> Result { let mut ret = "".to_string(); for arg in args.iter() { - let val: String = arg.try_into()?; - ret += val.as_str(); + let val: &str = arg.try_into()?; + ret += val; } Ok(TVMRetValue::from(ret)) } } + let a = std::ffi::CString::new("a").unwrap(); + let b = std::ffi::CString::new("b").unwrap(); + let c = std::ffi::CString::new("c").unwrap(); let mut registered = function::Builder::default(); - registered.get_function("concate_str", true); + registered.get_function("concate_str"); assert!(registered.func.is_some()); - let a = "a".to_string(); - let b = "b".to_string(); - let c = "c".to_string(); let ret: String = registered - .args(&[a, b, c]) + .arg(&a) + .arg(&b) + .arg(&c) .invoke() .unwrap() .try_into() From 42645d442a2ebd84767afb46d05966a72c4cb872 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Mon, 18 Feb 2019 04:02:59 +0000 Subject: [PATCH 04/17] Fix tests --- tests/scripts/task_rust.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/scripts/task_rust.sh b/tests/scripts/task_rust.sh index be0181b4d95b..6e174208f2c7 100755 --- a/tests/scripts/task_rust.sh +++ b/tests/scripts/task_rust.sh @@ -14,11 +14,11 @@ cargo fmt -- --check # test common cd $RUST_DIR/common -cargo build --features runtime -cargo test --features runtime --tests +cargo build +cargo test --tests -cargo build --features frontend -cargo test --features frontend --tests +cargo build --features bindings +cargo test --features bindings --tests # test runtime cd $RUST_DIR/runtime From 2520796d4ad63013cb4d311ec8b26e0c67f34ca1 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Wed, 20 Feb 2019 17:22:29 +0000 Subject: [PATCH 05/17] PackedFunc returns result --- rust/common/src/lib.rs | 7 +++++++ rust/common/src/packed_func.rs | 2 +- rust/runtime/src/graph.rs | 2 +- rust/runtime/src/module.rs | 8 ++++++-- 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/rust/common/src/lib.rs b/rust/common/src/lib.rs index a1b243ecf6ab..fce9a94f90dd 100644 --- a/rust/common/src/lib.rs +++ b/rust/common/src/lib.rs @@ -19,6 +19,13 @@ pub mod ffi { pub type BackendPackedCFunc = extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int; + + pub fn get_last_error() -> String { + unsafe { std::ffi::CStr::from_ptr(TVMGetLastError()) } + .to_str() + .expect("double fault") + .to_owned() + } } pub mod array; diff --git a/rust/common/src/packed_func.rs b/rust/common/src/packed_func.rs index ddf0df6eb4ca..ffeb53d6c4e6 100644 --- a/rust/common/src/packed_func.rs +++ b/rust/common/src/packed_func.rs @@ -3,7 +3,7 @@ use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void, str: pub use crate::ffi::TVMValue; use crate::{errors::*, ffi::*, value::*}; -pub type PackedFunc = Box TVMRetValue + Send + Sync>; +pub type PackedFunc = Box Result>; /// Calls a packed function and returns a `TVMRetValue`. /// diff --git a/rust/runtime/src/graph.rs b/rust/runtime/src/graph.rs index 8f2d5630a5b3..054c7890d81f 100644 --- a/rust/runtime/src/graph.rs +++ b/rust/runtime/src/graph.rs @@ -279,7 +279,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { .iter() .map(|t| t.into()) .collect::>(); - func(args.as_slice()); + func(args.as_slice()).unwrap(); }; op_execs.push(op); } diff --git a/rust/runtime/src/module.rs b/rust/runtime/src/module.rs index 10f17f59015e..dbc0ded589a3 100644 --- a/rust/runtime/src/module.rs +++ b/rust/runtime/src/module.rs @@ -36,7 +36,7 @@ impl Default for SystemLibModule { // @see `WrapPackedFunc` in `llvm_module.cc`. pub(super) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc { box move |args: &[TVMArgValue]| { - func( + let exit_code = func( args.iter() .map(|ref arg| arg.value) .collect::>() @@ -47,7 +47,11 @@ pub(super) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc { .as_ptr() as *const i32, args.len() as i32, ); - TVMRetValue::default() + if exit_code == 0 { + Ok(TVMRetValue::default()) + } else { + Err(tvm_common::ffi::get_last_error().into()) + } } } From 05f1ae30084f5f6ad9ea08ac3c6cf09d335b48ef Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Wed, 20 Feb 2019 17:54:23 +0000 Subject: [PATCH 06/17] Remove unused imports --- rust/common/src/array.rs | 3 --- rust/common/src/lib.rs | 3 --- rust/common/src/packed_func.rs | 4 ++-- 3 files changed, 2 insertions(+), 8 deletions(-) diff --git a/rust/common/src/array.rs b/rust/common/src/array.rs index 9dbea1d12e1d..e7b75850677d 100644 --- a/rust/common/src/array.rs +++ b/rust/common/src/array.rs @@ -1,10 +1,7 @@ use std::{ any::TypeId, - convert::TryFrom, mem, - ops::{Deref, DerefMut}, os::raw::{c_int, c_void}, - ptr, slice, }; use crate::ffi::{ diff --git a/rust/common/src/lib.rs b/rust/common/src/lib.rs index fce9a94f90dd..b06666545825 100644 --- a/rust/common/src/lib.rs +++ b/rust/common/src/lib.rs @@ -1,9 +1,6 @@ //! This crate contains the refactored basic components required //! for `runtime` and `frontend` TVM crates. -#![crate_name = "tvm_common"] -#![recursion_limit = "1024"] -#![allow(non_camel_case_types, unused_imports)] #![feature(box_syntax, try_from)] #[macro_use] diff --git a/rust/common/src/packed_func.rs b/rust/common/src/packed_func.rs index ffeb53d6c4e6..7942e67caa0c 100644 --- a/rust/common/src/packed_func.rs +++ b/rust/common/src/packed_func.rs @@ -1,7 +1,7 @@ -use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void, str::FromStr}; +use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void}; pub use crate::ffi::TVMValue; -use crate::{errors::*, ffi::*, value::*}; +use crate::{errors::*, ffi::*}; pub type PackedFunc = Box Result>; From 2f577701e57239b9cd705b521d827b4d86c5a027 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Wed, 20 Feb 2019 23:03:36 +0000 Subject: [PATCH 07/17] Use failure in tvm_common --- rust/common/Cargo.toml | 2 +- rust/common/src/errors.rs | 68 +++++++++++++++++++++------ rust/common/src/lib.rs | 9 +--- rust/common/src/packed_func.rs | 84 +++++++++++++++++++--------------- rust/common/src/value.rs | 10 ++-- 5 files changed, 109 insertions(+), 64 deletions(-) diff --git a/rust/common/Cargo.toml b/rust/common/Cargo.toml index 3c095df6c692..5d21ee509b02 100644 --- a/rust/common/Cargo.toml +++ b/rust/common/Cargo.toml @@ -8,7 +8,7 @@ license = "Apache-2.0" bindings = [] [dependencies] -error-chain = { version = "0.12.0", default-features = false } +failure = "0.1.5" ndarray = "0.12.1" [build-dependencies] diff --git a/rust/common/src/errors.rs b/rust/common/src/errors.rs index 0f9e9dfee304..ad72f36433c0 100644 --- a/rust/common/src/errors.rs +++ b/rust/common/src/errors.rs @@ -1,4 +1,4 @@ -//! Error types for `TVMArgValue` and `TVMRetValue` conversions. +use std::fmt; static TYPE_CODE_STRS: [&str; 15] = [ "int", @@ -17,21 +17,63 @@ static TYPE_CODE_STRS: [&str; 15] = [ "NDArrayContainer", "ExtBegin", ]; -fn type_code_to_string(type_code: &i64) -> String { - TYPE_CODE_STRS[*type_code as usize].to_string() + +#[derive(Debug, Fail)] +pub struct ValueDowncastError { + actual_type_code: i64, + expected_type_code: i64, } -error_chain! { - errors { - TryFromTVMRetValueError(expected_type: String, actual_type_code: i64) { - description("mismatched types while downcasting TVMRetValue") - display("invalid downcast: expected `{}` but was `{}`", - expected_type, type_code_to_string(actual_type_code)) +impl ValueDowncastError { + pub fn new(actual_type_code: i64, expected_type_code: i64) -> Self { + Self { + actual_type_code, + expected_type_code, } } - foreign_links { - IntoString(std::ffi::IntoStringError); - ParseInt(std::num::ParseIntError); - Utf8(std::str::Utf8Error); +} + +impl fmt::Display for ValueDowncastError { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!( + formatter, + "Could not downcast TVMValue: expected `{}` but was {}", + TYPE_CODE_STRS[self.actual_type_code as usize], + TYPE_CODE_STRS[self.expected_type_code as usize] + ) + } +} + +#[derive(Debug, Fail)] +#[fail(display = "Function call `{}` returned error: {}", context, message)] +pub struct FuncCallError { + context: String, + message: String, +} + +impl FuncCallError { + pub fn get_with_context(context: String) -> Self { + Self { + context, + message: unsafe { std::ffi::CStr::from_ptr(crate::ffi::TVMGetLastError()) } + .to_str() + .expect("double fault") + .to_owned(), + } } } + +// error_chain! { +// errors { +// TryFromTVMRetValueError(expected_type: String, actual_type_code: i64) { +// description("mismatched types while downcasting TVMRetValue") +// display("invalid downcast: expected `{}` but was `{}`", +// expected_type, type_code_to_string(actual_type_code)) +// } +// } +// foreign_links { +// IntoString(std::ffi::IntoStringError); +// ParseInt(std::num::ParseIntError); +// Utf8(std::str::Utf8Error); +// } +// } diff --git a/rust/common/src/lib.rs b/rust/common/src/lib.rs index b06666545825..337103f19ccb 100644 --- a/rust/common/src/lib.rs +++ b/rust/common/src/lib.rs @@ -4,7 +4,7 @@ #![feature(box_syntax, try_from)] #[macro_use] -extern crate error_chain; +extern crate failure; /// Unified ffi module for both runtime and frontend crates. pub mod ffi { @@ -16,13 +16,6 @@ pub mod ffi { pub type BackendPackedCFunc = extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int; - - pub fn get_last_error() -> String { - unsafe { std::ffi::CStr::from_ptr(TVMGetLastError()) } - .to_str() - .expect("double fault") - .to_owned() - } } pub mod array; diff --git a/rust/common/src/packed_func.rs b/rust/common/src/packed_func.rs index 7942e67caa0c..c2068570eb0e 100644 --- a/rust/common/src/packed_func.rs +++ b/rust/common/src/packed_func.rs @@ -1,9 +1,11 @@ use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void}; +use failure::Error; + pub use crate::ffi::TVMValue; -use crate::{errors::*, ffi::*}; +use crate::ffi::*; -pub type PackedFunc = Box Result>; +pub type PackedFunc = Box Result>; /// Calls a packed function and returns a `TVMRetValue`. /// @@ -41,25 +43,26 @@ impl<'a> TVMArgValue<'a> { #[macro_export] macro_rules! ensure_type { - ($val:ident, $actual:expr) => { + ($val:ident, $expected_type_code:expr) => { ensure!( - $val.type_code == $actual as i64, - "Could not downcast value. Expected type code `{}`, got `{}`", - $actual, - $val.type_code + $val.type_code == $expected_type_code as i64, + crate::errors::ValueDowncastError::new( + $val.type_code as i64, + $expected_type_code as i64 + ) ); }; } /// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode. macro_rules! impl_prim_tvm_arg { - (@, $field:ident, $code:expr, $as:ty, $( $type:ty ),+) => { + ($type_code:ident, $field:ident, $field_type:ty, [ $( $type:ty ),+ ] ) => { $( impl<'a> From<$type> for TVMArgValue<'a> { fn from(val: $type) -> Self { TVMArgValue { - value: TVMValue { $field: val as $as }, - type_code: $code as i64, + value: TVMValue { $field: val as $field_type }, + type_code: $type_code as i64, _lifetime: PhantomData, } } @@ -68,40 +71,45 @@ macro_rules! impl_prim_tvm_arg { fn from(val: &'a $type) -> Self { TVMArgValue { value: TVMValue { - $field: val.to_owned() as $as, + $field: val.to_owned() as $field_type, }, - type_code: $code as i64, + type_code: $type_code as i64, _lifetime: PhantomData, } } } impl<'a> TryFrom> for $type { - type Error = Error; - fn try_from(val: TVMArgValue<'a>) -> Result { - ensure_type!(val, $code); + type Error = Error; + fn try_from(val: TVMArgValue<'a>) -> Result { + ensure_type!(val, $type_code); Ok(unsafe { val.value.$field as $type }) } } impl<'a> TryFrom<&TVMArgValue<'a>> for $type { - type Error = Error; - fn try_from(val: &TVMArgValue<'a>) -> Result { - ensure_type!(val, $code); + type Error = Error; + fn try_from(val: &TVMArgValue<'a>) -> Result { + ensure_type!(val, $type_code); Ok(unsafe { val.value.$field as $type }) } } )+ }; - (v_int64, $( $type:ty ),+) => { - impl_prim_tvm_arg!(@, v_int64, DLDataTypeCode_kDLInt, i64, $( $type ),+); - }; - (v_float64, $( $type:ty ),+) => { - impl_prim_tvm_arg!(@, v_float64, DLDataTypeCode_kDLFloat, f64, $( $type ),+); - }; } -impl_prim_tvm_arg!(v_float64, f32, f64); -impl_prim_tvm_arg!(v_int64, i8, u8, i32, u32, i64, u64, isize, usize); +impl_prim_tvm_arg!(DLDataTypeCode_kDLFloat, v_float64, f64, [f32, f64]); +impl_prim_tvm_arg!( + DLDataTypeCode_kDLInt, + v_int64, + i64, + [i8, i16, i32, i64, isize] +); +impl_prim_tvm_arg!( + DLDataTypeCode_kDLUInt, + v_int64, + i64, + [u8, u16, u32, u64, usize] +); impl<'a> From<&std::ffi::CString> for TVMArgValue<'a> { fn from(string: &std::ffi::CString) -> Self { @@ -117,7 +125,7 @@ impl<'a> From<&std::ffi::CString> for TVMArgValue<'a> { impl<'a> TryFrom> for &str { type Error = Error; - fn try_from(arg: TVMArgValue<'a>) -> Result { + fn try_from(arg: TVMArgValue<'a>) -> Result { ensure_type!(arg, TVMTypeCode_kStr); Ok(unsafe { std::ffi::CStr::from_ptr(arg.value.v_handle as *const i8) }.to_str()?) } @@ -125,7 +133,7 @@ impl<'a> TryFrom> for &str { impl<'a> TryFrom<&TVMArgValue<'a>> for &str { type Error = Error; - fn try_from(arg: &TVMArgValue<'a>) -> Result { + fn try_from(arg: &TVMArgValue<'a>) -> Result { ensure_type!(arg, TVMTypeCode_kStr); Ok(unsafe { std::ffi::CStr::from_ptr(arg.value.v_handle as *const i8) }.to_str()?) } @@ -183,7 +191,7 @@ impl<'a> From<&'a DLTensor> for TVMArgValue<'a> { impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for TVMType { type Error = Error; - fn try_from(arg: &'a TVMArgValue<'v>) -> Result { + fn try_from(arg: &'a TVMArgValue<'v>) -> Result { ensure_type!(arg, TVMTypeCode_kTVMType); Ok(unsafe { arg.value.v_type.into() }) } @@ -233,7 +241,7 @@ impl Default for TVMRetValue { } macro_rules! impl_pod_ret_value { - ($code:expr, $( $ty:ty ),+ ) => { + ($code:expr, [ $( $ty:ty ),+ ] ) => { $( impl From<$ty> for TVMRetValue { fn from(val: $ty) -> Self { @@ -246,8 +254,8 @@ macro_rules! impl_pod_ret_value { } impl TryFrom for $ty { - type Error = Error; - fn try_from(ret: TVMRetValue) -> Result<$ty> { + type Error = Error; + fn try_from(ret: TVMRetValue) -> Result<$ty, Self::Error> { ensure_type!(ret, $code); Ok(ret.value.into()) } @@ -256,15 +264,15 @@ macro_rules! impl_pod_ret_value { }; } -impl_pod_ret_value!(DLDataTypeCode_kDLInt, i8, i16, i32, i64, isize); -impl_pod_ret_value!(DLDataTypeCode_kDLUInt, u8, u16, u32, u64, usize); -impl_pod_ret_value!(DLDataTypeCode_kDLFloat, f32, f64); -impl_pod_ret_value!(TVMTypeCode_kTVMType, TVMType); -impl_pod_ret_value!(TVMTypeCode_kTVMContext, TVMContext); +impl_pod_ret_value!(DLDataTypeCode_kDLInt, [i8, i16, i32, i64, isize]); +impl_pod_ret_value!(DLDataTypeCode_kDLUInt, [u8, u16, u32, u64, usize]); +impl_pod_ret_value!(DLDataTypeCode_kDLFloat, [f32, f64]); +impl_pod_ret_value!(TVMTypeCode_kTVMType, [TVMType]); +impl_pod_ret_value!(TVMTypeCode_kTVMContext, [TVMContext]); impl TryFrom for String { type Error = Error; - fn try_from(ret: TVMRetValue) -> Result { + fn try_from(ret: TVMRetValue) -> Result { ensure_type!(ret, TVMTypeCode_kStr); Ok(unsafe { std::ffi::CString::from_raw(ret.value.v_str as *mut i8) }.into_string()?) } diff --git a/rust/common/src/value.rs b/rust/common/src/value.rs index 83ad8cff8f38..c7c040b0060e 100644 --- a/rust/common/src/value.rs +++ b/rust/common/src/value.rs @@ -1,5 +1,7 @@ use std::str::FromStr; +use failure::Error; + use crate::ffi::*; impl TVMType { @@ -15,7 +17,7 @@ impl TVMType { /// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}` /// such as "int32", "float32" or with lane "float32x1". impl FromStr for TVMType { - type Err = crate::errors::Error; + type Err = Error; fn from_str(type_str: &str) -> Result { if type_str == "bool" { return Ok(TVMType::new(1, 1, 1)); @@ -40,7 +42,7 @@ impl FromStr for TVMType { "uint" => 1, "float" => 2, "handle" => 3, - _ => return Err(format!("Unknown type {}", type_name).into()), + _ => return Err(format_err!("Unknown type {}", type_name)), }; Ok(TVMType::new(type_code, bits, lanes)) @@ -99,12 +101,12 @@ macro_rules! impl_tvm_context { ( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => { /// Creates a TVMContext from a string (e.g., "cpu", "gpu", "ext_dev") impl FromStr for TVMContext { - type Err = crate::errors::Error; + type Err = Error; fn from_str(type_str: &str) -> Result { Ok(Self { device_type: match type_str { $( $( stringify!($dev_name) )|+ => $dev_type ),+, - _ => return Err(format!("device {} not supported", type_str).into()), + _ => return Err(format_err!("device {} not supported", type_str).into()), }, device_id: 0, }) From 909b3ba341c2ce641dcca6d55f47d3c7330ea08d Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Wed, 20 Feb 2019 23:03:46 +0000 Subject: [PATCH 08/17] Use failure in tvm_runtime --- rust/runtime/Cargo.toml | 2 +- rust/runtime/src/allocator.rs | 4 +- rust/runtime/src/array.rs | 7 ++- rust/runtime/src/errors.rs | 80 +++++++++++++++----------- rust/runtime/src/graph.rs | 102 ++++++++++++++-------------------- rust/runtime/src/lib.rs | 2 +- rust/runtime/src/module.rs | 9 ++- rust/runtime/src/sgx.rs | 2 +- rust/runtime/src/threading.rs | 5 +- rust/runtime/src/workspace.rs | 13 +++-- 10 files changed, 114 insertions(+), 112 deletions(-) diff --git a/rust/runtime/Cargo.toml b/rust/runtime/Cargo.toml index dcb7de5836f3..ae73ae721224 100644 --- a/rust/runtime/Cargo.toml +++ b/rust/runtime/Cargo.toml @@ -15,7 +15,7 @@ sgx = ["nom/alloc"] [dependencies] bounded-spsc-queue = "0.4.0" -error-chain = { version = "0.12.0", default-features = false } +failure = "0.1.5" itertools = "0.7.8" lazy_static = "1.1.0" ndarray="0.12.1" diff --git a/rust/runtime/src/allocator.rs b/rust/runtime/src/allocator.rs index 5f77037e25f3..fba325012784 100644 --- a/rust/runtime/src/allocator.rs +++ b/rust/runtime/src/allocator.rs @@ -3,7 +3,7 @@ use alloc::alloc::{self, Layout}; #[cfg(not(target_env = "sgx"))] use std::alloc::{self, Layout}; -use crate::errors::*; +use failure::Error; const DEFAULT_ALIGN_BYTES: usize = 4; @@ -15,7 +15,7 @@ pub struct Allocation { impl Allocation { /// Allocates a chunk of memory of `size` bytes with optional alignment. - pub fn new(size: usize, align: Option) -> Result { + pub fn new(size: usize, align: Option) -> Result { let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES); let layout = Layout::from_size_align(size, alignment)?; let ptr = unsafe { alloc::alloc(layout.clone()) }; diff --git a/rust/runtime/src/array.rs b/rust/runtime/src/array.rs index eceba761535d..3bb02f12c866 100644 --- a/rust/runtime/src/array.rs +++ b/rust/runtime/src/array.rs @@ -1,5 +1,6 @@ use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice}; +use failure::Error; use ndarray; use tvm_common::{ array::{DataType, TVMContext}, @@ -9,7 +10,7 @@ use tvm_common::{ }, }; -use crate::{allocator::Allocation, errors::*}; +use crate::allocator::Allocation; /// A `Storage` is a container which holds `Tensor` data. #[derive(PartialEq)] @@ -22,7 +23,7 @@ pub enum Storage<'a> { } impl<'a> Storage<'a> { - pub fn new(size: usize, align: Option) -> Result> { + pub fn new(size: usize, align: Option) -> Result, Error> { Ok(Storage::Owned(Allocation::new(size, align)?)) } @@ -258,7 +259,7 @@ macro_rules! impl_ndarray_try_from_tensor { ($type:ty, $dtype:expr) => { impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> { type Error = Error; - fn try_from(tensor: &'a Tensor) -> Result> { + fn try_from(tensor: &'a Tensor) -> Result, Error> { ensure!( tensor.dtype == $dtype, "Cannot convert Tensor with dtype {:?} to ndarray", diff --git a/rust/runtime/src/errors.rs b/rust/runtime/src/errors.rs index ab56d2580fd2..5688f3699e8b 100644 --- a/rust/runtime/src/errors.rs +++ b/rust/runtime/src/errors.rs @@ -1,35 +1,49 @@ -#[cfg(target_env = "sgx")] -use alloc::alloc; -#[cfg(not(target_env = "sgx"))] -use std::alloc; -use std::num; - -use ndarray; -use serde_json; - -error_chain! { - errors { - GraphFormatError(msg: String) { - description("unable to load graph") - display("could not load graph json: {}", msg) - } - - LoadGraphParamsError(msg: String) { - description("unable to load graph params") - display("could not load graph params: {}", msg) - } - } - foreign_links { - Alloc(alloc::AllocErr); - GraphDeserialize(serde_json::Error); - ParseInt(num::ParseIntError); - ShapeError(ndarray::ShapeError); - CommonError(tvm_common::errors::Error); - } +#[derive(Debug, Fail)] +pub enum GraphFormatError { + #[fail(display = "Could not parse graph json")] + Parse(#[fail(cause)] failure::Error), + #[fail(display = "Could not parse graph params")] + Params, + #[fail(display = "{} is missing attr: {}", 0, 1)] + MissingAttr(String, String), + #[fail(display = "Missing field: {}", 0)] + MissingField(&'static str), + #[fail(display = "Invalid DLType: {}", 0)] + InvalidDLType(String), } -impl From for Error { - fn from(_err: alloc::LayoutErr) -> Error { - Error::from_kind(ErrorKind::Msg("Layout error".to_string())) - } -} +// #[cfg(target_env = "sgx")] +// use alloc::alloc; +// #[cfg(not(target_env = "sgx"))] +// use std::alloc; +// use std::num; +// +// use ndarray; +// use serde_json; + +// error_chain! { +// errors { +// GraphFormatError(msg: String) { +// description("unable to load graph") +// display("could not load graph json: {}", msg) +// } +// +// LoadGraphParamsError(msg: String) { +// description("unable to load graph params") +// display("could not load graph params: {}", msg) +// } +// } +// foreign_links { +// Alloc(alloc::AllocErr); +// GraphDeserialize(serde_json::Error); +// ParseInt(num::ParseIntError); +// ShapeError(ndarray::ShapeError); +// CommonError(tvm_common::errors::Error); +// } +// } +// +// impl From for Error { +// fn from(_err: alloc::LayoutErr) -> Error { +// Error::from_kind(ErrorKind::Msg("Layout error".to_string())) +// } +// } diff --git a/rust/runtime/src/graph.rs b/rust/runtime/src/graph.rs index 054c7890d81f..94194d45160b 100644 --- a/rust/runtime/src/graph.rs +++ b/rust/runtime/src/graph.rs @@ -1,19 +1,17 @@ use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str}; +use failure::Error; use nom::{alpha1, digit1, le_i32, le_i64, le_u16, le_u32, le_u64, le_u8, types::CompleteStr}; use serde; use serde_json; - -use crate::{ - errors::{Error, ErrorKind, Result}, - Module, Storage, Tensor, -}; use tvm_common::{ array::{DataType, TVMContext}, ffi::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, DLTensor}, TVMArgValue, }; +use crate::{errors::GraphFormatError, Module, Storage, Tensor}; + // @see `kTVMNDArrayMagic` in `ndarray.h` const _NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F; // @see `kTVMNDArrayListMagic` in `graph_runtime.h` @@ -44,28 +42,26 @@ pub struct Entry { } impl Graph { - fn entry_index(&self, entry: &Entry) -> Result { + fn entry_index(&self, entry: &Entry) -> Result { self.node_row_ptr .as_ref() .map(|nrp| nrp[entry.id] + entry.index) - .ok_or("Missing node_row_ptr.".into()) + .ok_or_else(|| GraphFormatError::MissingField("node_row_ptr")) } /// Attempt to deserialize a JSON attribute to a type `T`. - fn get_attr(&self, attr: &str) -> Result { + fn get_attr(&self, attr: &str) -> Result { Ok(serde_json::from_value::( self.attrs .as_ref() - .ok_or(ErrorKind::GraphFormatError( - "Missing graph attrs".to_string(), - ))? + .ok_or(GraphFormatError::MissingField("attrs"))? .get(attr) - .ok_or(ErrorKind::GraphFormatError(format!( - "Missing {} attr", - attr - )))? + .ok_or_else(|| { + GraphFormatError::MissingAttr("graph".to_string(), attr.to_string()) + })? .to_owned(), - )?) + ) + .map_err(|err| GraphFormatError::Parse(err.into()))?) } } @@ -84,39 +80,31 @@ struct NodeAttrs { flatten_data: bool, } +macro_rules! get_node_attr { + ($node:expr, $attrs:ident, $attr:literal) => { + $attrs + .get($attr) + .ok_or_else(|| GraphFormatError::MissingAttr($node.to_owned(), $attr.to_owned())) + }; +} + impl Node { - fn parse_attrs(&self) -> Result { + fn parse_attrs(&self) -> Result { let attrs = self .attrs .as_ref() - .ok_or(format!("Missing node.attrs for `{}`", self.name))?; - let func_name = attrs - .get("func_name") - .ok_or(format!("Node `{}` is missing attrs.func_name", self.name))? - .to_string(); - let num_outputs = attrs - .get("num_outputs") - .ok_or(format!("Node `{}` is missing attrs.num_outputs", self.name))? - .parse::()?; - let flatten_data = attrs - .get("flatten_data") - .ok_or(format!( - "Node `{}` is missing attrs.flatten_data", - self.name - ))? - .parse::()? - == 1; + .ok_or_else(|| GraphFormatError::MissingAttr(self.name.clone(), "attrs".to_owned()))?; Ok(NodeAttrs { - func_name, - num_outputs, - flatten_data, + func_name: get_node_attr!(self.name, attrs, "func_name")?.to_owned(), + num_outputs: get_node_attr!(self.name, attrs, "num_outputs")?.parse::()?, + flatten_data: get_node_attr!(self.name, attrs, "flatten_data")?.parse::()? == 1, }) } } impl<'a> TryFrom<&'a String> for Graph { type Error = Error; - fn try_from(graph_json: &String) -> Result { + fn try_from(graph_json: &String) -> Result { let graph = serde_json::from_str(graph_json)?; Ok(graph) } @@ -124,7 +112,7 @@ impl<'a> TryFrom<&'a String> for Graph { impl<'a> TryFrom<&'a str> for Graph { type Error = Error; - fn try_from(graph_json: &'a str) -> Result { + fn try_from(graph_json: &'a str) -> Result { let graph = serde_json::from_str(graph_json)?; Ok(graph) } @@ -164,7 +152,7 @@ pub struct GraphExecutor<'m, 't> { unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {} impl<'m, 't> GraphExecutor<'m, 't> { - pub fn new(graph: Graph, lib: &'m M) -> Result { + pub fn new(graph: Graph, lib: &'m M) -> Result { let tensors = Self::setup_storages(&graph)?; Ok(GraphExecutor { op_execs: Self::setup_op_execs(&graph, lib, &tensors)?, @@ -181,7 +169,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { } /// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output. - fn setup_storages<'a>(graph: &'a Graph) -> Result>> { + fn setup_storages<'a>(graph: &'a Graph) -> Result>, Error> { let storage_ids = graph.get_attr::<(String, Vec)>("storage_id")?.1; let shapes = graph.get_attr::<(String, Vec>)>("shape")?.1; let dtypes = graph @@ -192,13 +180,10 @@ impl<'m, 't> GraphExecutor<'m, 't> { if let Ok((_, dtype)) = tvm_str_to_type(CompleteStr(dltype)) { Ok(dtype) } else { - Err(ErrorKind::GraphFormatError( - format!("Invalid dltype: {}", dltype).to_string(), - ) - .into()) + Err(GraphFormatError::InvalidDLType(dltype.to_string())) } }) - .collect::>>()?; + .collect::, GraphFormatError>>()?; let align = dtypes.iter().map(|dtype| dtype.bits() as usize).max(); let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1]; @@ -211,7 +196,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { let mut storages: Vec = storage_num_bytes .into_iter() .map(|nbytes| Storage::new(nbytes, align)) - .collect::>>()?; + .collect::, Error>>()?; let tensors = izip!(storage_ids, shapes, dtypes) .map(|(storage_id, shape, dtype)| { @@ -236,7 +221,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { graph: &Graph, lib: &'m M, tensors: &Vec>, - ) -> Result>> { + ) -> Result>, Error> { ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr."); let node_row_ptr = graph.node_row_ptr.as_ref().unwrap(); @@ -254,9 +239,10 @@ impl<'m, 't> GraphExecutor<'m, 't> { continue; } - let func = lib - .get_function(&attrs.func_name) - .ok_or(format!("Missing function {}", attrs.func_name))?; + let func = lib.get_function(&attrs.func_name).ok_or(format_err!( + "Library is missing function {}", + attrs.func_name + ))?; let arg_indices = node .inputs .iter() @@ -272,7 +258,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { DLTensor::from(tensor) }) }) - .collect::>>() + .collect::, Error>>() .unwrap(); let op: Box = box move || { let args = dl_tensors @@ -436,17 +422,15 @@ named!( ); /// Loads a param dict saved using `nnvm.compiler.save_param_dict`. -pub fn load_param_dict(bytes: &[u8]) -> Result> { +pub fn load_param_dict(bytes: &[u8]) -> Result, GraphFormatError> { if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) { - if remaining_bytes.len() > 0 { - bail!(ErrorKind::LoadGraphParamsError("extra input".to_string())) - } else { + if remaining_bytes.len() == 0 { Ok(param_dict) + } else { + Err(GraphFormatError::Params) } } else { - bail!(ErrorKind::LoadGraphParamsError( - "invalid parameters file".to_string() - )) + Err(GraphFormatError::Params) } } diff --git a/rust/runtime/src/lib.rs b/rust/runtime/src/lib.rs index 9716cda3bf46..f8d87ed5a0bb 100644 --- a/rust/runtime/src/lib.rs +++ b/rust/runtime/src/lib.rs @@ -25,7 +25,7 @@ extern crate bounded_spsc_queue; #[cfg(target_env = "sgx")] extern crate core; #[macro_use] -extern crate error_chain; +extern crate failure; #[macro_use] extern crate itertools; #[macro_use] diff --git a/rust/runtime/src/module.rs b/rust/runtime/src/module.rs index dbc0ded589a3..6e914ce1ae25 100644 --- a/rust/runtime/src/module.rs +++ b/rust/runtime/src/module.rs @@ -1,6 +1,7 @@ use std::{ collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex, }; + use tvm_common::{ ffi::BackendPackedCFunc, packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue}, @@ -23,7 +24,7 @@ impl Module for SystemLibModule { .lock() .unwrap() .get(name.as_ref()) - .map(|func| wrap_backend_packed_func(func.to_owned())) + .map(|func| wrap_backend_packed_func(name.as_ref().to_owned(), func.to_owned())) } } @@ -34,7 +35,7 @@ impl Default for SystemLibModule { } // @see `WrapPackedFunc` in `llvm_module.cc`. -pub(super) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc { +pub(super) fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> PackedFunc { box move |args: &[TVMArgValue]| { let exit_code = func( args.iter() @@ -50,7 +51,9 @@ pub(super) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc { if exit_code == 0 { Ok(TVMRetValue::default()) } else { - Err(tvm_common::ffi::get_last_error().into()) + Err(tvm_common::errors::FuncCallError::get_with_context( + func_name.clone(), + )) } } } diff --git a/rust/runtime/src/sgx.rs b/rust/runtime/src/sgx.rs index 1c5cb1ae4a40..f8570355db60 100644 --- a/rust/runtime/src/sgx.rs +++ b/rust/runtime/src/sgx.rs @@ -14,7 +14,7 @@ macro_rules! tvm_ocall { ($func: expr) => { match $func { 0 => Ok(()), - err => Err(format!("SGX error: {}", err)), + err => Err(format_err!("SGX error: {}", err)), } }; } diff --git a/rust/runtime/src/threading.rs b/rust/runtime/src/threading.rs index e6b0a6d11698..3c34a25e0014 100644 --- a/rust/runtime/src/threading.rs +++ b/rust/runtime/src/threading.rs @@ -18,10 +18,9 @@ use std::{ use std::{collections::VecDeque, ptr, sync::Mutex}; use bounded_spsc_queue::{self, Producer}; +use failure::Error; use tvm_common::ffi::TVMParallelGroupEnv; -use crate::errors::*; - #[cfg(target_env = "sgx")] use super::{sgx::ocall_packed_func, TVMArgValue, TVMRetValue}; @@ -63,7 +62,7 @@ impl Job { } /// Waits for all tasks in this `Job` to be completed. - fn wait(&self) -> Result<()> { + fn wait(&self) -> Result<(), Error> { while self.pending.load(Ordering::Acquire) > 0 { #[cfg(not(target_env = "sgx"))] thread::yield_now(); diff --git a/rust/runtime/src/workspace.rs b/rust/runtime/src/workspace.rs index a12a27e4c47c..1e29ec179bca 100644 --- a/rust/runtime/src/workspace.rs +++ b/rust/runtime/src/workspace.rs @@ -4,8 +4,9 @@ use std::{ ptr, }; -use super::allocator::Allocation; -use crate::errors::*; +use failure::Error; + +use crate::allocator::Allocation; const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h` @@ -24,13 +25,13 @@ impl WorkspacePool { } } - fn alloc_new(&mut self, size: usize) -> Result<*mut u8> { + fn alloc_new(&mut self, size: usize) -> Result<*mut u8, Error> { self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?); self.in_use.push(self.workspaces.len() - 1); Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr()) } - fn alloc(&mut self, size: usize) -> Result<*mut u8> { + fn alloc(&mut self, size: usize) -> Result<*mut u8, Error> { if self.free.len() == 0 { return self.alloc_new(size); } @@ -60,7 +61,7 @@ impl WorkspacePool { } } - fn free(&mut self, ptr: *mut u8) -> Result<()> { + fn free(&mut self, ptr: *mut u8) -> Result<(), Error> { let mut ws_idx = None; for i in 0..self.in_use.len() { let idx = self.in_use[i]; @@ -72,7 +73,7 @@ impl WorkspacePool { } Ok(self .free - .push(ws_idx.ok_or("Tried to free nonexistent workspace.")?)) + .push(ws_idx.ok_or(format_err!("Tried to free nonexistent workspace."))?)) } } From b6ed7360a854e198d1d1ba3dc9ba8099366e1793 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Mon, 25 Feb 2019 08:49:32 +0000 Subject: [PATCH 09/17] Use optimal PackedFunc --- rust/common/src/lib.rs | 2 +- rust/common/src/packed_func.rs | 2 +- rust/runtime/src/module.rs | 25 ++++++++++++++----------- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/rust/common/src/lib.rs b/rust/common/src/lib.rs index 337103f19ccb..8e980e158ba8 100644 --- a/rust/common/src/lib.rs +++ b/rust/common/src/lib.rs @@ -1,7 +1,7 @@ //! This crate contains the refactored basic components required //! for `runtime` and `frontend` TVM crates. -#![feature(box_syntax, try_from)] +#![feature(box_syntax, trait_alias, try_from)] #[macro_use] extern crate failure; diff --git a/rust/common/src/packed_func.rs b/rust/common/src/packed_func.rs index c2068570eb0e..c51c05e4359a 100644 --- a/rust/common/src/packed_func.rs +++ b/rust/common/src/packed_func.rs @@ -5,7 +5,7 @@ use failure::Error; pub use crate::ffi::TVMValue; use crate::ffi::*; -pub type PackedFunc = Box Result>; +pub trait PackedFunc = Fn(&[TVMArgValue]) -> Result + Send + Sync; /// Calls a packed function and returns a `TVMRetValue`. /// diff --git a/rust/runtime/src/module.rs b/rust/runtime/src/module.rs index 6e914ce1ae25..f5ca12e02fc7 100644 --- a/rust/runtime/src/module.rs +++ b/rust/runtime/src/module.rs @@ -8,23 +8,23 @@ use tvm_common::{ }; pub trait Module { - fn get_function>(&self, name: S) -> Option; + fn get_function>(&self, name: S) -> Option<&(dyn PackedFunc)>; } pub struct SystemLibModule; lazy_static! { - static ref SYSTEM_LIB_FUNCTIONS: Mutex> = + static ref SYSTEM_LIB_FUNCTIONS: Mutex> = Mutex::new(HashMap::new()); } impl Module for SystemLibModule { - fn get_function>(&self, name: S) -> Option { + fn get_function>(&self, name: S) -> Option<&(dyn PackedFunc)> { SYSTEM_LIB_FUNCTIONS .lock() .unwrap() .get(name.as_ref()) - .map(|func| wrap_backend_packed_func(name.as_ref().to_owned(), func.to_owned())) + .map(|f| *f) } } @@ -35,8 +35,11 @@ impl Default for SystemLibModule { } // @see `WrapPackedFunc` in `llvm_module.cc`. -pub(super) fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> PackedFunc { - box move |args: &[TVMArgValue]| { +pub(super) fn wrap_backend_packed_func( + func_name: String, + func: BackendPackedCFunc, +) -> Box { + Box::new(move |args: &[TVMArgValue]| { let exit_code = func( args.iter() .map(|ref arg| arg.value) @@ -55,7 +58,7 @@ pub(super) fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFu func_name.clone(), )) } - } + }) } #[no_mangle] @@ -64,9 +67,9 @@ pub extern "C" fn TVMBackendRegisterSystemLibSymbol( func: BackendPackedCFunc, ) -> i32 { let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() }; - SYSTEM_LIB_FUNCTIONS - .lock() - .unwrap() - .insert(name.to_string(), func); + SYSTEM_LIB_FUNCTIONS.lock().unwrap().insert( + name.to_string(), + &*Box::leak(wrap_backend_packed_func(name.to_string(), func)), + ); return 0; } From 5da3ca3fb61cec58dee739deb99934710453bae3 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Mon, 25 Feb 2019 09:31:08 +0000 Subject: [PATCH 10/17] Add GetLastError to runtime --- rust/runtime/src/lib.rs | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/rust/runtime/src/lib.rs b/rust/runtime/src/lib.rs index f8d87ed5a0bb..0afa6c2ac5d2 100644 --- a/rust/runtime/src/lib.rs +++ b/rust/runtime/src/lib.rs @@ -65,12 +65,22 @@ pub use self::{array::*, errors::*, graph::*, module::*, threading::*, workspace #[cfg(target_env = "sgx")] use self::sgx::ocall_packed_func; +lazy_static! { + static ref LAST_ERROR: std::sync::RwLock> = + std::sync::RwLock::new(None); +} + #[no_mangle] pub extern "C" fn TVMAPISetLastError(cmsg: *const i8) { - #[cfg(not(target_env = "sgx"))] - unsafe { - panic!(std::ffi::CStr::from_ptr(cmsg).to_str().unwrap()); - } + *LAST_ERROR.write().unwrap() = Some(unsafe { std::ffi::CStr::from_ptr(cmsg) }); #[cfg(target_env = "sgx")] ocall_packed!("__sgx_set_last_error__", cmsg); } + +#[no_mangle] +pub extern "C" fn TVMGetLastError() -> *const std::os::raw::c_char { + match *LAST_ERROR.read().unwrap() { + Some(err) => err.as_ptr(), + None => std::ptr::null(), + } +} From 022a1449e514c7dc8f8b0bbee3a3eea6363085b7 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Mon, 25 Feb 2019 09:31:46 +0000 Subject: [PATCH 11/17] Fix lints in runtime tests --- rust/runtime/tests/test_nnvm/src/main.rs | 1 - rust/runtime/tests/test_tvm_basic/src/main.rs | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/rust/runtime/tests/test_nnvm/src/main.rs b/rust/runtime/tests/test_nnvm/src/main.rs index 8329b9e86c28..50179798cd32 100644 --- a/rust/runtime/tests/test_nnvm/src/main.rs +++ b/rust/runtime/tests/test_nnvm/src/main.rs @@ -5,7 +5,6 @@ extern crate ndarray; extern crate serde; extern crate serde_json; -#[macro_use] extern crate tvm_runtime; use std::{collections::HashMap, convert::TryFrom, fs, io::Read}; diff --git a/rust/runtime/tests/test_tvm_basic/src/main.rs b/rust/runtime/tests/test_tvm_basic/src/main.rs index f14fbec8c439..621315ddd34f 100644 --- a/rust/runtime/tests/test_tvm_basic/src/main.rs +++ b/rust/runtime/tests/test_tvm_basic/src/main.rs @@ -17,6 +17,6 @@ fn main() { let mut a_dl: DLTensor = (&mut a).into(); let mut b_dl: DLTensor = (&mut b).into(); let mut c_dl: DLTensor = (&mut c).into(); - call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl); + call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap(); assert!(c.all_close(&e, 1e-8f32)); } From 97660dfea85b5768e9f508de9ddfa94c64617f13 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Mon, 25 Feb 2019 09:42:21 +0000 Subject: [PATCH 12/17] box_syntax --- rust/runtime/src/module.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rust/runtime/src/module.rs b/rust/runtime/src/module.rs index f5ca12e02fc7..636c4e8ff5cf 100644 --- a/rust/runtime/src/module.rs +++ b/rust/runtime/src/module.rs @@ -39,7 +39,7 @@ pub(super) fn wrap_backend_packed_func( func_name: String, func: BackendPackedCFunc, ) -> Box { - Box::new(move |args: &[TVMArgValue]| { + box move |args: &[TVMArgValue]| { let exit_code = func( args.iter() .map(|ref arg| arg.value) @@ -58,7 +58,7 @@ pub(super) fn wrap_backend_packed_func( func_name.clone(), )) } - }) + } } #[no_mangle] From c88257640ae7801fcbda4a9c873738e535abf953 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Sat, 23 Mar 2019 08:10:36 +0000 Subject: [PATCH 13/17] Fix warnings --- rust/common/src/lib.rs | 2 +- rust/runtime/src/graph.rs | 12 ++++++------ rust/runtime/src/lib.rs | 1 - 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/rust/common/src/lib.rs b/rust/common/src/lib.rs index 8e980e158ba8..966655e802f8 100644 --- a/rust/common/src/lib.rs +++ b/rust/common/src/lib.rs @@ -1,7 +1,7 @@ //! This crate contains the refactored basic components required //! for `runtime` and `frontend` TVM crates. -#![feature(box_syntax, trait_alias, try_from)] +#![feature(box_syntax, trait_alias)] #[macro_use] extern crate failure; diff --git a/rust/runtime/src/graph.rs b/rust/runtime/src/graph.rs index 94194d45160b..6e00d9c7a14c 100644 --- a/rust/runtime/src/graph.rs +++ b/rust/runtime/src/graph.rs @@ -333,7 +333,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { } } -/// Converts a string to TVM DLDataTypeCode. @see `String2TVMType` in packed_func.h +// Converts a string to TVM DLDataTypeCode. @see `String2TVMType` in packed_func.h named!( tvm_str_to_type, do_parse!( @@ -356,7 +356,7 @@ named!( ) ); -/// Converts a bytes to String. +// Converts a bytes to String. named!( name, map_res!(length_bytes!(le_u64), |b: &[u8]| String::from_utf8( @@ -364,7 +364,7 @@ named!( )) ); -/// Parses a TVMContext +// Parses a TVMContext named!( tvm_ctx<&[u8], TVMContext>, do_parse!( @@ -374,7 +374,7 @@ named!( ) ); -/// Parses a DataType +// Parses a DataType named!( data_type<&[u8], DataType>, do_parse!( @@ -385,7 +385,7 @@ named!( ) ); -/// Parses a Tensor from a TVM array file. +// Parses a Tensor from a TVM array file. named!( tensor, do_parse!( @@ -409,7 +409,7 @@ named!( ) ); -/// Parses a graph params dict from a params binary file. +// Parses a graph params dict from a params binary file. named!( parse_param_dict>, do_parse!( diff --git a/rust/runtime/src/lib.rs b/rust/runtime/src/lib.rs index 0afa6c2ac5d2..5a0faf7a0ac1 100644 --- a/rust/runtime/src/lib.rs +++ b/rust/runtime/src/lib.rs @@ -14,7 +14,6 @@ allocator_api, box_syntax, fn_traits, - try_from, unboxed_closures, vec_remove_item )] From 89750eb0d0529d733a4c29a67ea22f01a015ab65 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Sat, 23 Mar 2019 19:18:36 +0000 Subject: [PATCH 14/17] Use failure in frontend --- rust/common/src/packed_func.rs | 18 +++-- rust/common/src/value.rs | 2 +- rust/frontend/Cargo.toml | 2 +- rust/frontend/src/context.rs | 8 +- rust/frontend/src/errors.rs | 68 +++++----------- rust/frontend/src/function.rs | 55 +++++++------ rust/frontend/src/lib.rs | 12 +-- rust/frontend/src/module.rs | 34 +++++--- rust/frontend/src/ndarray.rs | 77 +++++++++++-------- rust/frontend/src/value.rs | 16 ++-- rust/frontend/tests/basics/src/main.rs | 2 - rust/frontend/tests/callback/src/bin/array.rs | 5 +- rust/frontend/tests/callback/src/bin/error.rs | 14 ++-- rust/frontend/tests/callback/src/bin/float.rs | 5 +- rust/frontend/tests/callback/src/bin/int.rs | 5 +- .../frontend/tests/callback/src/bin/string.rs | 5 +- 16 files changed, 163 insertions(+), 165 deletions(-) diff --git a/rust/common/src/packed_func.rs b/rust/common/src/packed_func.rs index c51c05e4359a..8ced5cc7e3bf 100644 --- a/rust/common/src/packed_func.rs +++ b/rust/common/src/packed_func.rs @@ -5,7 +5,8 @@ use failure::Error; pub use crate::ffi::TVMValue; use crate::ffi::*; -pub trait PackedFunc = Fn(&[TVMArgValue]) -> Result + Send + Sync; +pub trait PackedFunc = + Fn(&[TVMArgValue]) -> Result + Send + Sync; /// Calls a packed function and returns a `TVMRetValue`. /// @@ -46,7 +47,7 @@ macro_rules! ensure_type { ($val:ident, $expected_type_code:expr) => { ensure!( $val.type_code == $expected_type_code as i64, - crate::errors::ValueDowncastError::new( + $crate::errors::ValueDowncastError::new( $val.type_code as i64, $expected_type_code as i64 ) @@ -274,18 +275,23 @@ impl TryFrom for String { type Error = Error; fn try_from(ret: TVMRetValue) -> Result { ensure_type!(ret, TVMTypeCode_kStr); - Ok(unsafe { std::ffi::CString::from_raw(ret.value.v_str as *mut i8) }.into_string()?) + let cs = unsafe { std::ffi::CString::from_raw(ret.value.v_handle as *mut i8) }; + let ret_str = cs.clone().into_string(); + if cfg!(feature = "bindings") { + std::mem::forget(cs); // TVM C++ takes ownership of CString. (@see TVMFuncCall) + } + Ok(ret_str?) } } impl From for TVMRetValue { fn from(s: String) -> Self { - let s_box = box std::ffi::CString::new(s).unwrap(); + let cs = std::ffi::CString::new(s).unwrap(); Self { value: TVMValue { - v_handle: s_box.as_ptr() as *mut i8 as *mut c_void, + v_str: cs.into_raw() as *mut i8, }, - box_value: s_box, + box_value: box (), type_code: TVMTypeCode_kStr as i64, } } diff --git a/rust/common/src/value.rs b/rust/common/src/value.rs index c7c040b0060e..739a2bf19b73 100644 --- a/rust/common/src/value.rs +++ b/rust/common/src/value.rs @@ -45,7 +45,7 @@ impl FromStr for TVMType { _ => return Err(format_err!("Unknown type {}", type_name)), }; - Ok(TVMType::new(type_code, bits, lanes)) + Ok(dbg!(TVMType::new(type_code, bits, lanes))) } } diff --git a/rust/frontend/Cargo.toml b/rust/frontend/Cargo.toml index cf09133e2a15..eb1f5b8db021 100644 --- a/rust/frontend/Cargo.toml +++ b/rust/frontend/Cargo.toml @@ -15,7 +15,7 @@ name = "tvm_frontend" crate-type = ["dylib"] [dependencies] -error-chain = "0.12.0" +failure = "0.1.5" lazy_static = "1.1.0" ndarray = "0.12.1" num-traits = "0.2" diff --git a/rust/frontend/src/context.rs b/rust/frontend/src/context.rs index 3972653ab163..ff25414c58b3 100644 --- a/rust/frontend/src/context.rs +++ b/rust/frontend/src/context.rs @@ -24,12 +24,14 @@ use std::{ ptr, }; +use failure::Error; + use tvm_common::{ - ffi::{self, TVMTypeCode_kTVMType, TVMValue}, + ffi::{self, TVMValue}, TVMArgValue, }; -use crate::{function, Result}; +use crate::function; /// Device type can be from a supported device name. See the supported devices /// in [TVM](https://github.com/dmlc/tvm). @@ -215,7 +217,7 @@ impl TVMContext { } /// Synchronize the context stream. - pub fn sync(&self) -> Result<()> { + pub fn sync(&self) -> Result<(), Error> { check_call!(ffi::TVMSynchronize( self.device_type.0 as i32, self.device_id as i32, diff --git a/rust/frontend/src/errors.rs b/rust/frontend/src/errors.rs index 6cef88c7f285..96a70caf59c0 100644 --- a/rust/frontend/src/errors.rs +++ b/rust/frontend/src/errors.rs @@ -1,54 +1,26 @@ -//! This module implements TVM custom [`Error`], [`ErrorKind`] and [`Result`] types. +pub use failure::Error; -use std::{ffi, option}; +#[derive(Debug, Fail)] +#[fail(display = "Cannot convert from an empty array.")] +pub struct EmptyArrayError; -use crate::{common_errors, rust_ndarray}; - -error_chain! { - errors { - EmptyArray { - description("cannot convert from an empty array") - } - - NullHandle(name: String) { - description("null handle") - display("requested `{}` handle is null", name) - } - - FunctionNotFound { - description("function not found") - display("function was not set in `function::Builder`") - } - - TypeMismatch(expected: String, found: String) { - description("type mismatch!") - display("expected type `{}`, but found `{}`", expected, found) - } - - MissingShapeError { - description("ndarray `shape()` returns `None`") - display("called `Option::unwrap()` on a `None` value") - } - - AtMostOneReturn { - description("TVM functions accept at most one return value") - } - - } +#[derive(Debug, Fail)] +#[fail(display = "Handle `{}` is null.", name)] +pub struct NullHandleError { + pub name: String, +} - foreign_links { - ShapeError(rust_ndarray::ShapeError); - NulError(ffi::NulError); - IntoStringError(ffi::IntoStringError); - } +#[derive(Debug, Fail)] +#[fail(display = "Function was not set in `function::Builder`")] +pub struct FunctionNotFoundError; - links { - CommonError(common_errors::Error, common_errors::ErrorKind); - } +#[derive(Debug, Fail)] +#[fail(display = "Expected type `{}` but found `{}`", expected, actual)] +pub struct TypeMismatchError { + pub expected: String, + pub actual: String, } -impl From for Error { - fn from(_err: option::NoneError) -> Self { - ErrorKind::MissingShapeError.into() - } -} +#[derive(Debug, Fail)] +#[fail(display = "Missing NDArray shape.")] +pub struct MissingShapeError; diff --git a/rust/frontend/src/function.rs b/rust/frontend/src/function.rs index 4792df82a2ee..004796fb48ee 100644 --- a/rust/frontend/src/function.rs +++ b/rust/frontend/src/function.rs @@ -15,9 +15,12 @@ use std::{ sync::Mutex, }; +use failure::Error; + use crate::{ + errors, ffi::{self, TVMValue}, - ErrorKind, Module, Result, TVMArgValue, TVMRetValue, + Module, TVMArgValue, TVMRetValue, }; lazy_static! { @@ -175,23 +178,17 @@ impl<'a, 'm> Builder<'a, 'm> { /// Sets an output for a function that requirs a mutable output to be provided. /// See the `basics` in tests for an example. - pub fn set_output(&mut self, mut ret: T) -> Result<&mut Self> + pub fn set_output(&mut self, ret: T) -> &mut Self where TVMRetValue: From, { - if self.ret_buf.is_none() { - self.ret_buf = Some(ret.into()) - } else { - bail!(ErrorKind::AtMostOneReturn) - } - Ok(self) + self.ret_buf = Some(ret.into()); + self } /// Calls the function that created from `Builder`. - pub fn invoke(&mut self) -> Result { - if self.func.is_none() { - bail!("{}", ErrorKind::FunctionNotFound); - } + pub fn invoke(&mut self) -> Result { + ensure!(self.func.is_some(), errors::FunctionNotFoundError); let num_args = self.arg_buf.len(); let (mut values, mut type_codes): (Vec, Vec) = self @@ -200,10 +197,10 @@ impl<'a, 'm> Builder<'a, 'm> { .map(|tvm_arg| (tvm_arg.value, tvm_arg.type_code as ffi::TVMTypeCode)) .unzip(); - let mut ret_val = TVMValue { v_int64: 4 }; + let mut ret_val = unsafe { std::mem::uninitialized::() }; let mut ret_type_code = 0; check_call!(ffi::TVMFuncCall( - self.func?.handle, + self.func.ok_or(errors::FunctionNotFoundError)?.handle, values.as_mut_ptr(), type_codes.as_mut_ptr() as *mut i32, num_args as c_int, @@ -211,8 +208,7 @@ impl<'a, 'm> Builder<'a, 'm> { &mut ret_type_code as *mut _ )); - let ret = unsafe { TVMRetValue::from_tvm_value(ret_val.into(), ret_type_code as i64) }; - Ok(ret) + Ok(unsafe { TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64) }) } } @@ -246,7 +242,8 @@ unsafe extern "C" fn tvm_callback( let mut local_args: Vec = Vec::new(); let mut value = mem::uninitialized::(); let mut tcode = mem::uninitialized::(); - let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result>(fhandle); + let rust_fn = + mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result>(fhandle); for i in 0..len { value = args_list[i]; tcode = type_codes_list[i]; @@ -279,13 +276,14 @@ unsafe extern "C" fn tvm_callback( } unsafe extern "C" fn tvm_callback_finalizer(fhandle: *mut c_void) { - let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result>(fhandle); + let rust_fn = + mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result>(fhandle); mem::drop(rust_fn); } -fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result) -> Function { +fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result) -> Function { let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; - let resource_handle = f as *mut fn(&[TVMArgValue]) -> Result; + let resource_handle = f as *mut fn(&[TVMArgValue]) -> Result; check_call!(ffi::TVMFuncCreateFromCFunc( Some(tvm_callback), resource_handle as *mut c_void, @@ -296,7 +294,7 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result) -> Function } /// Registers a Rust function with signature -/// `fn(&[TVMArgValue]) -> Result` +/// `fn(&[TVMArgValue]) -> Result` /// as a **global TVM packed function** from frontend to TVM backend. /// /// Use [`register_global_func`] if overriding an existing global TVM function @@ -307,7 +305,7 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result) -> Function /// ``` /// use std::convert::TryInto; /// -/// fn sum(args: &[TVMArgValue]) -> Result { +/// fn sum(args: &[TVMArgValue]) -> Result { /// let mut ret = 0i64; /// for arg in args.iter() { /// let arg: i64 = arg.try_into()?; @@ -325,18 +323,17 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result) -> Function /// assert_eq!(ret, 60); /// ``` pub fn register>( - f: fn(&[TVMArgValue]) -> Result, + f: fn(&[TVMArgValue]) -> Result, name: S, override_: bool, -) -> Result<()> { +) -> Result<(), Error> { let func = convert_to_tvm_func(f); let name = CString::new(name.as_ref())?; check_call!(ffi::TVMFuncRegisterGlobal( - name.as_ref().as_ptr() as *const c_char, + name.into_raw(), func.handle(), override_ as c_int )); - mem::forget(name); Ok(()) } @@ -350,7 +347,7 @@ pub fn register>( /// use std::convert::TryInto; /// /// register_global_func! { -/// fn sum(args: &[TVMArgValue]) -> Result { +/// fn sum(args: &[TVMArgValue]) -> Result { /// let mut ret = 0f64; /// for arg in args.iter() { /// let arg: f64 = arg.try_into()?; @@ -371,12 +368,12 @@ pub fn register>( macro_rules! register_global_func { { $(#[$m:meta])* - fn $fn_name:ident($args:ident : &[TVMArgValue]) -> Result { + fn $fn_name:ident($args:ident : &[TVMArgValue]) -> Result { $($code:tt)* } } => {{ $(#[$m])* - fn $fn_name($args: &[TVMArgValue]) -> Result { + fn $fn_name($args: &[TVMArgValue]) -> Result { $($code)* } diff --git a/rust/frontend/src/lib.rs b/rust/frontend/src/lib.rs index 0579e0de916f..cd1561090144 100644 --- a/rust/frontend/src/lib.rs +++ b/rust/frontend/src/lib.rs @@ -12,21 +12,23 @@ //! Checkout the `examples` repository for more details. #![allow(non_camel_case_types, unused_unsafe)] -#![feature(try_from, try_trait, fn_traits, unboxed_closures, box_syntax)] +#![feature(try_trait, fn_traits, unboxed_closures, box_syntax)] #[macro_use] -extern crate error_chain; -extern crate tvm_common; +extern crate failure; #[macro_use] extern crate lazy_static; extern crate ndarray as rust_ndarray; extern crate num_traits; +extern crate tvm_common; use std::{ ffi::{CStr, CString}, str, }; +use failure::Error; + pub use crate::{ bytearray::TVMByteArray, context::{TVMContext, TVMDeviceType}, @@ -95,8 +97,8 @@ mod tests { #[test] fn set_error() { - let err = ErrorKind::EmptyArray; + let err = errors::EmptyArrayError; set_last_error(&err.into()); - assert_eq!(get_last_error().trim(), ErrorKind::EmptyArray.to_string()); + assert_eq!(get_last_error().trim(), errors::EmptyArrayError.to_string()); } } diff --git a/rust/frontend/src/module.rs b/rust/frontend/src/module.rs index 575cfdaae0fd..9c27387520dc 100644 --- a/rust/frontend/src/module.rs +++ b/rust/frontend/src/module.rs @@ -8,9 +8,10 @@ use std::{ ptr, }; +use failure::Error; use tvm_common::ffi; -use crate::{function::Function, ErrorKind, Result}; +use crate::{errors, function::Function}; const ENTRY_FUNC: &'static str = "__tvm_main__"; @@ -40,7 +41,7 @@ impl Module { } /// Gets a function by name from a registered module. - pub fn get_function(&self, name: &str, query_import: bool) -> Result { + pub fn get_function(&self, name: &str, query_import: bool) -> Result { let name = CString::new(name)?; let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; check_call!(ffi::TVMModGetFunction( @@ -49,11 +50,13 @@ impl Module { query_import as c_int, &mut fhandle as *mut _ )); - if fhandle.is_null() { - bail!(ErrorKind::NullHandle(format!("{}", name.into_string()?))) - } else { - Ok(Function::new(fhandle)) - } + ensure!( + !fhandle.is_null(), + errors::NullHandleError { + name: format!("{}", name.into_string()?) + } + ); + Ok(Function::new(fhandle)) } /// Imports a dependent module such as `.ptx` for gpu. @@ -62,10 +65,21 @@ impl Module { } /// Loads a module shared library from path. - pub fn load>(path: &P) -> Result { - let ext = CString::new(path.as_ref().extension()?.to_str()?)?; + pub fn load>(path: &P) -> Result { + let ext = CString::new( + path.as_ref() + .extension() + .unwrap_or(std::ffi::OsStr::new("")) + .to_str() + .ok_or_else(|| { + format_err!("Bad module load path: `{}`.", path.as_ref().display()) + })?, + )?; let func = Function::get("module._LoadFromFile").expect("API function always exists"); - let cpath = CString::new(path.as_ref().to_str()?)?; + let cpath = + CString::new(path.as_ref().to_str().ok_or_else(|| { + format_err!("Bad module load path: `{}`.", path.as_ref().display()) + })?)?; let ret: Module = call_packed!(func, &cpath, &ext)?.try_into()?; Ok(ret) } diff --git a/rust/frontend/src/ndarray.rs b/rust/frontend/src/ndarray.rs index dca62d1d7a07..d676c61bfde5 100644 --- a/rust/frontend/src/ndarray.rs +++ b/rust/frontend/src/ndarray.rs @@ -25,11 +25,12 @@ use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; -use crate::rust_ndarray::{Array, ArrayD}; +use failure::Error; use num_traits::Num; +use rust_ndarray::{Array, ArrayD}; use tvm_common::{ffi, TVMType}; -use crate::{Error, ErrorKind, Result, TVMByteArray, TVMContext}; +use crate::{errors, TVMByteArray, TVMContext}; /// See the [`module-level documentation`](../ndarray/index.html) for more details. /// @@ -98,12 +99,13 @@ impl NDArray { } /// Shows whether the underlying ndarray is contiguous in memory or not. - pub fn is_contiguous(&self) -> Result { + pub fn is_contiguous(&self) -> Result { Ok(match self.strides() { None => true, Some(strides) => { - // MissingShapeError in case shape is not determined - self.shape()? + // errors::MissingShapeError in case shape is not determined + self.shape() + .ok_or(errors::MissingShapeError)? .iter() .zip(strides) .rfold( @@ -137,14 +139,16 @@ impl NDArray { /// assert_eq!(ndarray.shape(), Some(shape)); /// assert_eq!(ndarray.to_vec::().unwrap(), data); /// ``` - pub fn to_vec(&self) -> Result> { - if self.shape().is_none() { - bail!("{}", ErrorKind::EmptyArray); - } - let earr = NDArray::empty(self.shape()?, TVMContext::cpu(0), self.dtype()); + pub fn to_vec(&self) -> Result, Error> { + ensure!(self.shape().is_some(), errors::EmptyArrayError); + let earr = NDArray::empty( + self.shape().ok_or(errors::MissingShapeError)?, + TVMContext::cpu(0), + self.dtype(), + ); let target = self.copy_to_ndarray(earr)?; let arr = unsafe { *(target.handle) }; - let sz = self.size()? as usize; + let sz = self.size().ok_or(errors::MissingShapeError)?; let mut v: Vec = Vec::with_capacity(sz * mem::size_of::()); unsafe { v.as_mut_ptr() @@ -155,7 +159,7 @@ impl NDArray { } /// Converts the NDArray to [`TVMByteArray`]. - pub fn to_bytearray(&self) -> Result { + pub fn to_bytearray(&self) -> Result { let v = self.to_vec::()?; Ok(TVMByteArray::from(&v)) } @@ -183,14 +187,14 @@ impl NDArray { } /// Copies the NDArray to another target NDArray. - pub fn copy_to_ndarray(&self, target: NDArray) -> Result { - if self.dtype() != target.dtype() { + pub fn copy_to_ndarray(&self, target: NDArray) -> Result { + if dbg!(self.dtype()) != dbg!(target.dtype()) { bail!( "{}", - ErrorKind::TypeMismatch( - format!("{}", self.dtype().to_string()), - format!("{}", target.dtype().to_string()), - ) + errors::TypeMismatchError { + expected: format!("{}", self.dtype().to_string()), + actual: format!("{}", target.dtype().to_string()), + } ); } check_call!(ffi::TVMArrayCopyFromTo( @@ -202,8 +206,12 @@ impl NDArray { } /// Copies the NDArray to a target context. - pub fn copy_to_ctx(&self, target: &TVMContext) -> Result { - let tmp = NDArray::empty(self.shape()?, target.clone(), self.dtype()); + pub fn copy_to_ctx(&self, target: &TVMContext) -> Result { + let tmp = NDArray::empty( + self.shape().ok_or(errors::MissingShapeError)?, + target.clone(), + self.dtype(), + ); let copy = self.copy_to_ndarray(tmp)?; Ok(copy) } @@ -213,11 +221,14 @@ impl NDArray { rnd: &ArrayD, ctx: TVMContext, dtype: TVMType, - ) -> Result { + ) -> Result { let mut shape = rnd.shape().to_vec(); let mut nd = NDArray::empty(&mut shape, ctx, dtype); let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T)); - nd.copy_from_buffer(buf.as_slice_mut()?); + nd.copy_from_buffer( + buf.as_slice_mut() + .expect("Array from iter must be contiguous."), + ); Ok(nd) } @@ -245,23 +256,25 @@ macro_rules! impl_from_ndarray_rustndarray { ($type:ty, $type_name:tt) => { impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> { type Error = Error; - fn try_from(nd: &NDArray) -> Result> { - if nd.shape().is_none() { - bail!("{}", ErrorKind::EmptyArray); - } + fn try_from(nd: &NDArray) -> Result, Self::Error> { + ensure!(nd.shape().is_some(), errors::MissingShapeError); assert_eq!(nd.dtype(), TVMType::from_str($type_name)?, "Type mismatch"); - Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?) + Ok(Array::from_shape_vec( + &*nd.shape().ok_or(errors::MissingShapeError)?, + nd.to_vec::<$type>()?, + )?) } } impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> { type Error = Error; - fn try_from(nd: &mut NDArray) -> Result> { - if nd.shape().is_none() { - bail!("{}", ErrorKind::EmptyArray); - } + fn try_from(nd: &mut NDArray) -> Result, Self::Error> { + ensure!(nd.shape().is_some(), errors::MissingShapeError); assert_eq!(nd.dtype(), TVMType::from_str($type_name)?, "Type mismatch"); - Ok(Array::from_shape_vec(&*nd.shape()?, nd.to_vec::<$type>()?)?) + Ok(Array::from_shape_vec( + &*nd.shape().ok_or(errors::MissingShapeError)?, + nd.to_vec::<$type>()?, + )?) } } }; diff --git a/rust/frontend/src/value.rs b/rust/frontend/src/value.rs index b579299d75ae..eb62f10cabec 100644 --- a/rust/frontend/src/value.rs +++ b/rust/frontend/src/value.rs @@ -4,6 +4,7 @@ use std::{convert::TryFrom, os::raw::c_void}; +use failure::Error; use tvm_common::{ ensure_type, ffi::{self, TVMValue}, @@ -42,9 +43,9 @@ macro_rules! impl_tvm_val_from_handle { impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for $ty { type Error = Error; - fn try_from(arg: &TVMArgValue<'v>) -> Result<$ty> { + fn try_from(arg: &TVMArgValue<'v>) -> Result<$ty, Self::Error> { ensure_type!(arg, $type_code); - Ok($ty::new(unsafe { *(arg.value.v_handle as *const $handle) })) + Ok($ty::new(unsafe { arg.value.v_handle as $handle })) } } @@ -62,7 +63,7 @@ macro_rules! impl_tvm_val_from_handle { impl TryFrom for $ty { type Error = Error; - fn try_from(ret: TVMRetValue) -> Result<$ty> { + fn try_from(ret: TVMRetValue) -> Result<$ty, Self::Error> { ensure_type!(ret, $type_code); Ok($ty::new(unsafe { ret.value.v_handle as $handle })) } @@ -99,14 +100,11 @@ macro_rules! impl_boxed_ret_value { } impl TryFrom for $type { type Error = Error; - fn try_from(ret: TVMRetValue) -> Result<$type> { + fn try_from(ret: TVMRetValue) -> Result<$type, Self::Error> { if let Ok(val) = ret.box_value.downcast::<$type>() { Ok(*val) } else { - bail!(ErrorKind::TryFromTVMRetValueError( - stringify!($type).to_string(), - ret.type_code - )) + bail!(ValueDowncastError::new($code as i64, ret.type_code as i64)) } } } @@ -118,7 +116,7 @@ impl_boxed_ret_value!(TVMByteArray, ffi::TVMTypeCode_kBytes); impl<'a, 'v> TryFrom<&'a TVMArgValue<'v>> for TVMByteArray { type Error = Error; - fn try_from(arg: &TVMArgValue<'v>) -> Result { + fn try_from(arg: &TVMArgValue<'v>) -> Result { ensure_type!(arg, ffi::TVMTypeCode_kBytes); Ok(TVMByteArray::new(unsafe { *(arg.value.v_handle as *mut ffi::TVMByteArray) diff --git a/rust/frontend/tests/basics/src/main.rs b/rust/frontend/tests/basics/src/main.rs index eedb521dad43..55c537bfd362 100644 --- a/rust/frontend/tests/basics/src/main.rs +++ b/rust/frontend/tests/basics/src/main.rs @@ -1,5 +1,3 @@ -#![feature(try_from)] - extern crate ndarray as rust_ndarray; extern crate tvm_frontend as tvm; diff --git a/rust/frontend/tests/callback/src/bin/array.rs b/rust/frontend/tests/callback/src/bin/array.rs index 52fe27417d70..e77ea435a057 100644 --- a/rust/frontend/tests/callback/src/bin/array.rs +++ b/rust/frontend/tests/callback/src/bin/array.rs @@ -1,4 +1,3 @@ -#![feature(extern_crate_item_prelude, try_from)] #![allow(unused_imports)] extern crate ndarray as rust_ndarray; @@ -11,11 +10,11 @@ use std::{ str::FromStr, }; -use tvm::*; +use tvm::{errors::Error, *}; fn main() { register_global_func! { - fn sum(args: &[TVMArgValue]) -> Result { + fn sum(args: &[TVMArgValue]) -> Result { let mut ret = 0f32; let shape = &mut [2]; for arg in args.iter() { diff --git a/rust/frontend/tests/callback/src/bin/error.rs b/rust/frontend/tests/callback/src/bin/error.rs index 1a6e7efd74da..24a1f07e9764 100644 --- a/rust/frontend/tests/callback/src/bin/error.rs +++ b/rust/frontend/tests/callback/src/bin/error.rs @@ -1,4 +1,4 @@ -#![feature(extern_crate_item_prelude, panic_info_message)] +#![feature(panic_info_message)] #![allow(unused_imports)] use std::panic; @@ -6,15 +6,15 @@ use std::panic; #[macro_use] extern crate tvm_frontend as tvm; -use tvm::*; +use tvm::{errors::Error, *}; fn main() { register_global_func! { - fn error(_args: &[TVMArgValue]) -> Result { - Err(ErrorKind::TypeMismatch( - format!("{}", "i64".to_string()), - format!("{}", "f64".to_string()), - ).into()) + fn error(_args: &[TVMArgValue]) -> Result { + Err(errors::TypeMismatchError{ + expected: "i64".to_string(), + actual: "f64".to_string(), + }.into()) } } diff --git a/rust/frontend/tests/callback/src/bin/float.rs b/rust/frontend/tests/callback/src/bin/float.rs index e5aa23b76378..a26487be8678 100644 --- a/rust/frontend/tests/callback/src/bin/float.rs +++ b/rust/frontend/tests/callback/src/bin/float.rs @@ -1,15 +1,14 @@ -#![feature(extern_crate_item_prelude, try_from)] #![allow(unused_imports)] #[macro_use] extern crate tvm_frontend as tvm; use std::convert::TryInto; -use tvm::*; +use tvm::{errors::Error, *}; fn main() { register_global_func! { - fn sum(args: &[TVMArgValue]) -> Result { + fn sum(args: &[TVMArgValue]) -> Result { let mut ret = 0.0; for arg in args.into_iter() { let val: f64 = arg.try_into()?; diff --git a/rust/frontend/tests/callback/src/bin/int.rs b/rust/frontend/tests/callback/src/bin/int.rs index 724f6082d8be..591f95a660a1 100644 --- a/rust/frontend/tests/callback/src/bin/int.rs +++ b/rust/frontend/tests/callback/src/bin/int.rs @@ -1,13 +1,12 @@ -#![feature(extern_crate_item_prelude, try_from)] #![allow(unused_imports)] extern crate tvm_frontend as tvm; use std::convert::TryInto; -use tvm::*; +use tvm::{errors::Error, *}; fn main() { - fn sum(args: &[TVMArgValue]) -> Result { + fn sum(args: &[TVMArgValue]) -> Result { let mut ret = 0i64; for arg in args.iter() { let val: i64 = arg.try_into()?; diff --git a/rust/frontend/tests/callback/src/bin/string.rs b/rust/frontend/tests/callback/src/bin/string.rs index 346019b1cab5..3b2ad65a2f45 100644 --- a/rust/frontend/tests/callback/src/bin/string.rs +++ b/rust/frontend/tests/callback/src/bin/string.rs @@ -1,15 +1,14 @@ -#![feature(try_from)] #![allow(unused_imports)] #[macro_use] extern crate tvm_frontend as tvm; use std::convert::TryInto; -use tvm::*; +use tvm::{errors::Error, *}; // FIXME fn main() { register_global_func! { - fn concate_str(args: &[TVMArgValue]) -> Result { + fn concate_str(args: &[TVMArgValue]) -> Result { let mut ret = "".to_string(); for arg in args.iter() { let val: &str = arg.try_into()?; From 32343421b8fbea4d4c3d398c314255c515b86152 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Sat, 23 Mar 2019 19:58:03 +0000 Subject: [PATCH 15/17] Fix SGX build --- rust/runtime/src/allocator.rs | 8 +++---- rust/runtime/src/errors.rs | 40 +++++------------------------------ rust/runtime/src/lib.rs | 3 --- rust/runtime/src/sgx.rs | 18 +++++++++------- rust/runtime/src/threading.rs | 8 +++---- 5 files changed, 21 insertions(+), 56 deletions(-) diff --git a/rust/runtime/src/allocator.rs b/rust/runtime/src/allocator.rs index fba325012784..0514dce2b0a7 100644 --- a/rust/runtime/src/allocator.rs +++ b/rust/runtime/src/allocator.rs @@ -1,9 +1,7 @@ #[cfg(target_env = "sgx")] -use alloc::alloc::{self, Layout}; +use alloc::alloc::{self, Layout, LayoutErr}; #[cfg(not(target_env = "sgx"))] -use std::alloc::{self, Layout}; - -use failure::Error; +use std::alloc::{self, Layout, LayoutErr}; const DEFAULT_ALIGN_BYTES: usize = 4; @@ -15,7 +13,7 @@ pub struct Allocation { impl Allocation { /// Allocates a chunk of memory of `size` bytes with optional alignment. - pub fn new(size: usize, align: Option) -> Result { + pub fn new(size: usize, align: Option) -> Result { let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES); let layout = Layout::from_size_align(size, alignment)?; let ptr = unsafe { alloc::alloc(layout.clone()) }; diff --git a/rust/runtime/src/errors.rs b/rust/runtime/src/errors.rs index 5688f3699e8b..26a8961697c3 100644 --- a/rust/runtime/src/errors.rs +++ b/rust/runtime/src/errors.rs @@ -12,38 +12,8 @@ pub enum GraphFormatError { InvalidDLType(String), } -// #[cfg(target_env = "sgx")] -// use alloc::alloc; -// #[cfg(not(target_env = "sgx"))] -// use std::alloc; -// use std::num; -// -// use ndarray; -// use serde_json; - -// error_chain! { -// errors { -// GraphFormatError(msg: String) { -// description("unable to load graph") -// display("could not load graph json: {}", msg) -// } -// -// LoadGraphParamsError(msg: String) { -// description("unable to load graph params") -// display("could not load graph params: {}", msg) -// } -// } -// foreign_links { -// Alloc(alloc::AllocErr); -// GraphDeserialize(serde_json::Error); -// ParseInt(num::ParseIntError); -// ShapeError(ndarray::ShapeError); -// CommonError(tvm_common::errors::Error); -// } -// } -// -// impl From for Error { -// fn from(_err: alloc::LayoutErr) -> Error { -// Error::from_kind(ErrorKind::Msg("Layout error".to_string())) -// } -// } +#[derive(Debug, Fail)] +#[fail(display = "SGX error: 0x{:x}", code)] +pub struct SgxError { + pub code: u32, +} diff --git a/rust/runtime/src/lib.rs b/rust/runtime/src/lib.rs index 5a0faf7a0ac1..848db27ecdcc 100644 --- a/rust/runtime/src/lib.rs +++ b/rust/runtime/src/lib.rs @@ -61,9 +61,6 @@ pub use tvm_common::{ pub use self::{array::*, errors::*, graph::*, module::*, threading::*, workspace::*}; -#[cfg(target_env = "sgx")] -use self::sgx::ocall_packed_func; - lazy_static! { static ref LAST_ERROR: std::sync::RwLock> = std::sync::RwLock::new(None); diff --git a/rust/runtime/src/sgx.rs b/rust/runtime/src/sgx.rs index f8570355db60..42d3aa4aaa7e 100644 --- a/rust/runtime/src/sgx.rs +++ b/rust/runtime/src/sgx.rs @@ -3,18 +3,17 @@ use std::{ os::raw::{c_char, c_int}, }; -use errors::Result; +pub use crate::threading::tvm_run_worker as run_worker; +use crate::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue}; +use errors::SgxError; use ffi::TVMValue; -use runtime::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue}; - -pub use runtime::threading::tvm_run_worker as run_worker; #[macro_export] macro_rules! tvm_ocall { ($func: expr) => { match $func { 0 => Ok(()), - err => Err(format_err!("SGX error: {}", err)), + code => Err(SgxError { code }), } }; } @@ -33,7 +32,10 @@ extern "C" { ) -> SgxStatus; } -pub fn ocall_packed_func>(fn_name: S, args: &[TVMArgValue]) -> Result { +pub fn ocall_packed_func>( + fn_name: S, + args: &[TVMArgValue], +) -> Result { let mut ret_val = TVMValue { v_int64: 0 }; let ret_type_code = 0i64; unsafe { @@ -58,11 +60,11 @@ pub fn ocall_packed_func>(fn_name: S, args: &[TVMArgValue]) -> Res #[macro_export] macro_rules! ocall_packed { ($fn_name:expr, $($args:expr),+) => { - ocall_packed_func($fn_name, &[$($args.into(),)+]) + $crate::sgx::ocall_packed_func($fn_name, &[$($args.into(),)+]) .expect(concat!("Error calling `", $fn_name, "`")) }; ($fn_name:expr) => { - ocall_packed_func($fn_name, &Vec::new()) + $crate::sgx::ocall_packed_func($fn_name, &Vec::new()) .expect(concat!("Error calling `", $fn_name, "`")) } } diff --git a/rust/runtime/src/threading.rs b/rust/runtime/src/threading.rs index 3c34a25e0014..408c0b491bb0 100644 --- a/rust/runtime/src/threading.rs +++ b/rust/runtime/src/threading.rs @@ -18,11 +18,10 @@ use std::{ use std::{collections::VecDeque, ptr, sync::Mutex}; use bounded_spsc_queue::{self, Producer}; -use failure::Error; use tvm_common::ffi::TVMParallelGroupEnv; #[cfg(target_env = "sgx")] -use super::{sgx::ocall_packed_func, TVMArgValue, TVMRetValue}; +use super::{TVMArgValue, TVMRetValue}; type FTVMParallelLambda = extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32; @@ -62,12 +61,11 @@ impl Job { } /// Waits for all tasks in this `Job` to be completed. - fn wait(&self) -> Result<(), Error> { + fn wait(&self) { while self.pending.load(Ordering::Acquire) > 0 { #[cfg(not(target_env = "sgx"))] thread::yield_now(); } - Ok(()) } } @@ -161,7 +159,7 @@ impl ThreadPool { } tasks.pop().unwrap()(); - job.wait().unwrap(); + job.wait(); } fn run_worker(queue: Consumer) { From 506a92058be4cf8efba46742c5530f00f552559d Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Sun, 24 Mar 2019 09:03:48 +0000 Subject: [PATCH 16/17] Update resnet example --- rust/common/src/packed_func.rs | 18 ++++++++++++++++-- rust/common/src/value.rs | 2 +- rust/frontend/examples/resnet/src/main.rs | 8 ++------ rust/frontend/src/context.rs | 14 +++++++------- rust/frontend/src/ndarray.rs | 2 +- 5 files changed, 27 insertions(+), 17 deletions(-) diff --git a/rust/common/src/packed_func.rs b/rust/common/src/packed_func.rs index 8ced5cc7e3bf..a564fe656415 100644 --- a/rust/common/src/packed_func.rs +++ b/rust/common/src/packed_func.rs @@ -59,7 +59,7 @@ macro_rules! ensure_type { macro_rules! impl_prim_tvm_arg { ($type_code:ident, $field:ident, $field_type:ty, [ $( $type:ty ),+ ] ) => { $( - impl<'a> From<$type> for TVMArgValue<'a> { + impl From<$type> for TVMArgValue<'static> { fn from(val: $type) -> Self { TVMArgValue { value: TVMValue { $field: val as $field_type }, @@ -112,11 +112,25 @@ impl_prim_tvm_arg!( [u8, u16, u32, u64, usize] ); +#[cfg(feature = "bindings")] +// only allow this in bindings because pure-rust can't take ownership of leaked CString +impl<'a> From<&String> for TVMArgValue<'a> { + fn from(string: &String) -> Self { + TVMArgValue { + value: TVMValue { + v_str: std::ffi::CString::new(string.clone()).unwrap().into_raw(), + }, + type_code: TVMTypeCode_kStr as i64, + _lifetime: PhantomData, + } + } +} + impl<'a> From<&std::ffi::CString> for TVMArgValue<'a> { fn from(string: &std::ffi::CString) -> Self { TVMArgValue { value: TVMValue { - v_handle: string.as_ptr() as *const _ as *mut c_void, + v_str: string.as_ptr(), }, type_code: TVMTypeCode_kStr as i64, _lifetime: PhantomData, diff --git a/rust/common/src/value.rs b/rust/common/src/value.rs index 739a2bf19b73..c7c040b0060e 100644 --- a/rust/common/src/value.rs +++ b/rust/common/src/value.rs @@ -45,7 +45,7 @@ impl FromStr for TVMType { _ => return Err(format_err!("Unknown type {}", type_name)), }; - Ok(dbg!(TVMType::new(type_code, bits, lanes))) + Ok(TVMType::new(type_code, bits, lanes)) } } diff --git a/rust/frontend/examples/resnet/src/main.rs b/rust/frontend/examples/resnet/src/main.rs index 060fad3c164a..cb323399daf6 100644 --- a/rust/frontend/examples/resnet/src/main.rs +++ b/rust/frontend/examples/resnet/src/main.rs @@ -1,5 +1,3 @@ -#![feature(try_from)] - extern crate csv; extern crate image; extern crate ndarray; @@ -55,10 +53,8 @@ fn main() { "input size is {:?}", input.shape().expect("cannot get the input shape") ); - let graph = std::ffi::CString::new( - fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_graph.json")).unwrap(), - ) - .unwrap(); + let graph = + fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_graph.json")).unwrap(); // load the built module let lib = Module::load(&Path::new(concat!( env!("CARGO_MANIFEST_DIR"), diff --git a/rust/frontend/src/context.rs b/rust/frontend/src/context.rs index ff25414c58b3..5d800a8b9644 100644 --- a/rust/frontend/src/context.rs +++ b/rust/frontend/src/context.rs @@ -158,15 +158,15 @@ pub struct TVMContext { /// Supported device types pub device_type: TVMDeviceType, /// Device id - pub device_id: usize, + pub device_id: i32, } impl TVMContext { /// Creates context from device type and id. - pub fn new(device_type: TVMDeviceType, device_id: usize) -> Self { + pub fn new(device_type: TVMDeviceType, device_id: i32) -> Self { TVMContext { - device_type: device_type, - device_id: device_id, + device_type, + device_id, } } } @@ -175,7 +175,7 @@ macro_rules! impl_ctxs { ($(($ctx:ident, $dldevt:expr));+) => { $( impl TVMContext { - pub fn $ctx(device_id: usize) -> Self { + pub fn $ctx(device_id: i32) -> Self { Self::new(TVMDeviceType($dldevt), device_id) } } @@ -238,7 +238,7 @@ macro_rules! impl_device_attrs { // `unwrap` is ok here because if there is any error, // if would occur in function call. function::Builder::from(func) - .args(&[dt, self.device_id, $attr_kind]) + .args(&[dt, self.device_id as usize, $attr_kind]) .invoke() .unwrap() .try_into() @@ -262,7 +262,7 @@ impl From for TVMContext { fn from(ctx: ffi::DLContext) -> Self { TVMContext { device_type: TVMDeviceType::from(ctx.device_type), - device_id: ctx.device_id as usize, + device_id: ctx.device_id, } } } diff --git a/rust/frontend/src/ndarray.rs b/rust/frontend/src/ndarray.rs index d676c61bfde5..1939c92c0f0b 100644 --- a/rust/frontend/src/ndarray.rs +++ b/rust/frontend/src/ndarray.rs @@ -188,7 +188,7 @@ impl NDArray { /// Copies the NDArray to another target NDArray. pub fn copy_to_ndarray(&self, target: NDArray) -> Result { - if dbg!(self.dtype()) != dbg!(target.dtype()) { + if self.dtype() != target.dtype() { bail!( "{}", errors::TypeMismatchError { From 9a02eae4c0bedd5efb461e625514771826d1aadc Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Fri, 29 Mar 2019 20:32:48 +0000 Subject: [PATCH 17/17] Address review comments --- rust/frontend/examples/resnet/src/main.rs | 2 +- rust/frontend/src/function.rs | 3 ++- rust/frontend/src/lib.rs | 3 +-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/rust/frontend/examples/resnet/src/main.rs b/rust/frontend/examples/resnet/src/main.rs index cb323399daf6..2ad3efa9082a 100644 --- a/rust/frontend/examples/resnet/src/main.rs +++ b/rust/frontend/examples/resnet/src/main.rs @@ -88,7 +88,7 @@ fn main() { .get_function("set_input", false) .unwrap(); - let data_str = std::ffi::CString::new("data").unwrap(); + let data_str = "data".to_string(); call_packed!(set_input_fn, &data_str, &input).unwrap(); // get `run` function from runtime module let ref run_fn = graph_runtime_module.get_function("run", false).unwrap(); diff --git a/rust/frontend/src/function.rs b/rust/frontend/src/function.rs index 004796fb48ee..f0fbcbe67e25 100644 --- a/rust/frontend/src/function.rs +++ b/rust/frontend/src/function.rs @@ -188,6 +188,7 @@ impl<'a, 'm> Builder<'a, 'm> { /// Calls the function that created from `Builder`. pub fn invoke(&mut self) -> Result { + #![allow(unused_unsafe)] ensure!(self.func.is_some(), errors::FunctionNotFoundError); let num_args = self.arg_buf.len(); @@ -235,7 +236,7 @@ unsafe extern "C" fn tvm_callback( fhandle: *mut c_void, ) -> c_int { // turning off the incorrect linter complaints - #![allow(unused_assignments)] + #![allow(unused_assignments, unused_unsafe)] let len = num_args as usize; let args_list = slice::from_raw_parts_mut(args, len); let type_codes_list = slice::from_raw_parts_mut(type_codes, len); diff --git a/rust/frontend/src/lib.rs b/rust/frontend/src/lib.rs index cd1561090144..a773b2735d9c 100644 --- a/rust/frontend/src/lib.rs +++ b/rust/frontend/src/lib.rs @@ -11,8 +11,7 @@ //! //! Checkout the `examples` repository for more details. -#![allow(non_camel_case_types, unused_unsafe)] -#![feature(try_trait, fn_traits, unboxed_closures, box_syntax)] +#![feature(box_syntax)] #[macro_use] extern crate failure;