From db6eb0f7dcbd0b28cb63f5b9dd3f6acc88e92eca Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 11 Aug 2021 10:22:30 -0700 Subject: [PATCH 01/13] Add C++ API for computing type key from type index --- include/tvm/runtime/c_runtime_api.h | 8 ++++++++ src/runtime/object.cc | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 17d1ba2a5132..1039590b34a8 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -520,6 +520,14 @@ TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex); */ TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); +/*! + * \brief Convert type key to type index. + * \param type_key The key of the type. + * \param out_tindex the corresponding type index. + * \return 0 when success, nonzero when failure happens + */ +TVM_DLL int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key); + /*! * \brief Increase the reference count of an object. * diff --git a/src/runtime/object.cc b/src/runtime/object.cc index 1892ce780a4c..f89d805da136 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -262,3 +262,11 @@ int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) { out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(type_key); API_END(); } + +int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key) { + API_BEGIN(); + auto key = tvm::runtime::Object::TypeIndex2Key(tindex); + *out_type_key = static_cast(malloc(key.size() + 1)); + strncpy(*out_type_key, key.c_str(), key.size()); + API_END(); +} From 3a131b74960cc9aa7d2b36091432854c72f5f8b6 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 11 Aug 2021 10:28:51 -0700 Subject: [PATCH 02/13] Try and isolate leak --- rust/tvm-rt/src/function.rs | 18 ++++++++++++++++++ rust/tvm/examples/resnet/src/main.rs | 27 +++++++++++++++------------ 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index 5db665cc7a48..1152f3f235b8 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -141,6 +141,24 @@ impl Function { let rv = RetValue::from_tvm_value(ret_val, ret_type_code as u32); + // // This is a temporary patch to ensure that the arguments are correclty dropped. + // let args: Vec = values.into_iter().zip(type_codes.into_iter()).map(|(value, type_code)| { + // ArgValue::from_tvm_value(value, type_code) + // }).collect(); + + // let mut objects_to_drop: Vec = vec![]; + // for arg in args { + // match arg { + // ArgValue::ObjectHandle(_) | ArgValue::ModuleHandle(_) | ArgValue::NDArrayHandle(_) => objects_to_drop.push(arg.try_into().unwrap()), + // _ => {} + // } + // } + + // drop(objects_to_drop); + + let obj: crate::ObjectRef = rv.clone().try_into().unwrap(); + println!("rv: {}", obj.count()); + Ok(rv) } } diff --git a/rust/tvm/examples/resnet/src/main.rs b/rust/tvm/examples/resnet/src/main.rs index bd0de1c56ba3..96d74e2260a4 100644 --- a/rust/tvm/examples/resnet/src/main.rs +++ b/rust/tvm/examples/resnet/src/main.rs @@ -78,24 +78,27 @@ fn main() -> anyhow::Result<()> { "/deploy_lib.so" )))?; - let mut graph_rt = GraphRt::create_from_parts(&graph, lib, dev)?; - // parse parameters and convert to TVMByteArray let params: Vec = fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params"))?; - println!("param bytes: {}", params.len()); - graph_rt.load_params(¶ms)?; - graph_rt.set_input("data", input)?; - graph_rt.run()?; + let mut output: Vec; + + loop { + let mut graph_rt = GraphRt::create_from_parts(&graph, lib.clone(), dev)?; - // prepare to get the output - let output_shape = &[1, 1000]; - let output = NDArray::empty(output_shape, Device::cpu(0), DataType::float(32, 1)); - graph_rt.get_output_into(0, output.clone())?; + graph_rt.load_params(¶ms)?; + graph_rt.set_input("data", input.clone())?; + graph_rt.run()?; - // flatten the output as Vec - let output = output.to_vec::()?; + // prepare to get the output + let output_shape = &[1, 1000]; + let output_nd = NDArray::empty(output_shape, Device::cpu(0), DataType::float(32, 1)); + graph_rt.get_output_into(0, output_nd.clone())?; + + // flatten the output as Vec + output = output_nd.to_vec::()?; + } // find the maximum entry in the output and its index let (argmax, max_prob) = output From 8226ca371150620f0cb97e3b85cd50004e799e00 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 11 Aug 2021 23:35:35 -0700 Subject: [PATCH 03/13] Rewrite the bindings to fix the ArgValue lifetime issue There are still quite a few issues left to resolve in this patch, but I believe the runtime changes stablize memory consumption as long as the parameters are only set once. ByteArray also has some totally broken unsafe code which I am unsure of how it was introduced. --- rust/tvm-macros/src/object.rs | 13 +---- rust/tvm-rt/src/array.rs | 13 +++-- rust/tvm-rt/src/function.rs | 54 ++++++----------- rust/tvm-rt/src/graph_rt.rs | 6 +- rust/tvm-rt/src/map.rs | 14 ++--- rust/tvm-rt/src/ndarray.rs | 15 +++++ rust/tvm-rt/src/object/mod.rs | 16 +++++- rust/tvm-rt/src/object/object_ptr.rs | 20 ++++--- rust/tvm-rt/src/to_function.rs | 86 +++++++++++++++++----------- rust/tvm-sys/src/byte_array.rs | 22 +++---- rust/tvm-sys/src/packed_func.rs | 16 +++--- rust/tvm/examples/resnet/src/main.rs | 5 +- 12 files changed, 149 insertions(+), 131 deletions(-) diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs index c84d0aab612f..4134da5fe6d9 100644 --- a/rust/tvm-macros/src/object.rs +++ b/rust/tvm-macros/src/object.rs @@ -147,8 +147,8 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { } } - impl<'a> From<#ref_id> for #tvm_rt_crate::ArgValue<'a> { - fn from(object_ref: #ref_id) -> #tvm_rt_crate::ArgValue<'a> { + impl<'a> From<&'a #ref_id> for #tvm_rt_crate::ArgValue<'a> { + fn from(object_ref: &'a #ref_id) -> #tvm_rt_crate::ArgValue<'a> { use std::ffi::c_void; let object_ptr = &object_ref.0; match object_ptr { @@ -156,18 +156,11 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { #tvm_rt_crate::ArgValue:: ObjectHandle(std::ptr::null::() as *mut c_void) } - Some(value) => value.clone().into() + Some(value) => value.into() } } } - impl<'a> From<&#ref_id> for #tvm_rt_crate::ArgValue<'a> { - fn from(object_ref: &#ref_id) -> #tvm_rt_crate::ArgValue<'a> { - let oref: #ref_id = object_ref.clone(); - #tvm_rt_crate::ArgValue::<'a>::from(oref) - } - } - impl<'a> std::convert::TryFrom<#tvm_rt_crate::ArgValue<'a>> for #ref_id { type Error = #error; diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index e8902b54f6ef..02c34a1d133f 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -45,19 +45,22 @@ external! { fn array_size(array: ObjectRef) -> i64; } -impl IsObjectRef for Array { +impl IsObjectRef for Array { type Object = Object; fn as_ptr(&self) -> Option<&ObjectPtr> { self.object.as_ptr() } + fn into_ptr(self) -> Option> { self.object.into_ptr() } + fn from_ptr(object_ptr: Option>) -> Self { let object_ref = match object_ptr { Some(o) => o.into(), _ => panic!(), }; + Array { object: object_ref, _data: PhantomData, @@ -67,7 +70,7 @@ impl IsObjectRef for Array { impl Array { pub fn from_vec(data: Vec) -> Result> { - let iter = data.into_iter().map(T::into_arg_value).collect(); + let iter = data.iter().map(T::into_arg_value).collect(); let func = Function::get("runtime.Array").expect( "runtime.Array function is not registered, this is most likely a build or linking error", @@ -151,9 +154,9 @@ impl FromIterator for Array { } } -impl<'a, T: IsObjectRef> From> for ArgValue<'a> { - fn from(array: Array) -> ArgValue<'a> { - array.object.into() +impl<'a, T: IsObjectRef> From<&'a Array> for ArgValue<'a> { + fn from(array: &'a Array) -> ArgValue<'a> { + (&array.object).into() } } diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index 1152f3f235b8..3a32c7ed9409 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -37,6 +37,7 @@ use crate::errors::Error; pub use super::to_function::{ToFunction, Typed}; pub use tvm_sys::{ffi, ArgValue, RetValue}; +use crate::object::AsArgValue; pub type Result = std::result::Result; @@ -141,24 +142,6 @@ impl Function { let rv = RetValue::from_tvm_value(ret_val, ret_type_code as u32); - // // This is a temporary patch to ensure that the arguments are correclty dropped. - // let args: Vec = values.into_iter().zip(type_codes.into_iter()).map(|(value, type_code)| { - // ArgValue::from_tvm_value(value, type_code) - // }).collect(); - - // let mut objects_to_drop: Vec = vec![]; - // for arg in args { - // match arg { - // ArgValue::ObjectHandle(_) | ArgValue::ModuleHandle(_) | ArgValue::NDArrayHandle(_) => objects_to_drop.push(arg.try_into().unwrap()), - // _ => {} - // } - // } - - // drop(objects_to_drop); - - let obj: crate::ObjectRef = rv.clone().try_into().unwrap(); - println!("rv: {}", obj.count()); - Ok(rv) } } @@ -171,12 +154,12 @@ macro_rules! impl_to_fn { where Error: From, Out: TryFrom, - $($t: Into>),* + $($t: for<'a> AsArgValue<'a>),* { fn from(func: Function) -> Self { #[allow(non_snake_case)] Box::new(move |$($t : $t),*| { - let args = vec![ $($t.into()),* ]; + let args = vec![ $((&$t).as_arg_value()),* ]; Ok(func.invoke(args)?.try_into()?) }) } @@ -281,7 +264,7 @@ impl<'a> TryFrom<&ArgValue<'a>> for Function { pub fn register>(f: F, name: S) -> Result<()> where F: ToFunction, - F: Typed, + F: for<'a> Typed<'a, I, O>, { register_override(f, name, false) } @@ -292,7 +275,7 @@ where pub fn register_override>(f: F, name: S, override_: bool) -> Result<()> where F: ToFunction, - F: Typed, + F: for<'a> Typed<'a, I, O>, { let func = f.to_function(); let name = name.into(); @@ -309,22 +292,23 @@ where } pub fn register_untyped>( - f: fn(Vec>) -> Result, + f: for<'a> fn(Vec>) -> Result, name: S, override_: bool, ) -> Result<()> { - // TODO(@jroesch): can we unify all the code. - let func = f.to_function(); - let name = name.into(); - // Not sure about this code - let handle = func.handle(); - let name = CString::new(name)?; - check_call!(ffi::TVMFuncRegisterGlobal( - name.into_raw(), - handle, - override_ as c_int - )); - Ok(()) + panic!("foo") + // // TODO(@jroesch): can we unify all the code. + // let func = ToFunction::, RetValue>::to_function(f); + // let name = name.into(); + // // Not sure about this code + // let handle = func.handle(); + // let name = CString::new(name)?; + // check_call!(ffi::TVMFuncRegisterGlobal( + // name.into_raw(), + // handle, + // override_ as c_int + // )); + // Ok(()) } #[cfg(test)] diff --git a/rust/tvm-rt/src/graph_rt.rs b/rust/tvm-rt/src/graph_rt.rs index 7db53d466665..5ac9710424e0 100644 --- a/rust/tvm-rt/src/graph_rt.rs +++ b/rust/tvm-rt/src/graph_rt.rs @@ -50,7 +50,7 @@ impl GraphRt { let runtime_create_fn_ret = runtime_create_fn.invoke(vec![ graph.into(), - lib.into(), + (&lib).into(), (&dev.device_type).into(), // NOTE you must pass the device id in as i32 because that's what TVM expects (dev.device_id as i32).into(), @@ -79,7 +79,7 @@ impl GraphRt { pub fn set_input(&mut self, name: &str, input: NDArray) -> Result<()> { let ref set_input_fn = self.module.get_function("set_input", false)?; - set_input_fn.invoke(vec![name.into(), input.into()])?; + set_input_fn.invoke(vec![name.into(), (&input).into()])?; Ok(()) } @@ -101,7 +101,7 @@ impl GraphRt { /// Extract the ith output from the graph executor and write the results into output. pub fn get_output_into(&mut self, i: i64, output: NDArray) -> Result<()> { let get_output_fn = self.module.get_function("get_output", false)?; - get_output_fn.invoke(vec![i.into(), output.into()])?; + get_output_fn.invoke(vec![i.into(), (&output).into()])?; Ok(()) } } diff --git a/rust/tvm-rt/src/map.rs b/rust/tvm-rt/src/map.rs index d6dfaf3641b8..dbfac6f205b3 100644 --- a/rust/tvm-rt/src/map.rs +++ b/rust/tvm-rt/src/map.rs @@ -58,18 +58,18 @@ external! { fn map_items(map: ObjectRef) -> Array; } -impl FromIterator<(K, V)> for Map +impl<'a, K: 'a, V: 'a> FromIterator<(&'a K, &'a V)> for Map where K: IsObjectRef, V: IsObjectRef, { - fn from_iter>(iter: T) -> Self { + fn from_iter>(iter: T) -> Self { let iter = iter.into_iter(); let (lower_bound, upper_bound) = iter.size_hint(); let mut buffer: Vec = Vec::with_capacity(upper_bound.unwrap_or(lower_bound) * 2); for (k, v) in iter { - buffer.push(k.into()); - buffer.push(v.into()) + buffer.push(k.into_arg_value()); + buffer.push(v.into_arg_value()); } Self::from_data(buffer).expect("failed to convert from data") } @@ -202,13 +202,13 @@ where } } -impl<'a, K, V> From> for ArgValue<'a> +impl<'a, K, V> From<&'a Map> for ArgValue<'a> where K: IsObjectRef, V: IsObjectRef, { - fn from(map: Map) -> ArgValue<'a> { - map.object.into() + fn from(map: &'a Map) -> ArgValue<'a> { + (&map.object).into() } } diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index 08dcfe33f28f..80f8f184140c 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -101,6 +101,21 @@ impl NDArrayContainer { .cast::() } } + + pub fn as_mut_ptr<'a>(object_ptr: &ObjectPtr) -> *mut NDArrayContainer + where + NDArrayContainer: 'a, + { + let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize; + unsafe { + object_ptr + .ptr + .as_ptr() + .cast::() + .offset(base_offset) + .cast::() + } + } } fn cow_usize<'a>(slice: &[i64]) -> Cow<'a, [usize]> { diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs index 8c07ed9f0853..075ef46f35e2 100644 --- a/rust/tvm-rt/src/object/mod.rs +++ b/rust/tvm-rt/src/object/mod.rs @@ -29,6 +29,16 @@ mod object_ptr; pub use object_ptr::{IsObject, Object, ObjectPtr, ObjectRef}; +pub trait AsArgValue<'a> { + fn as_arg_value(&'a self) -> ArgValue<'a>; +} + +impl<'a, T: 'static> AsArgValue<'a> for T where &'a T: Into> { + fn as_arg_value(&'a self) -> ArgValue<'a> { + self.into() + } +} + // TODO we would prefer to blanket impl From/TryFrom ArgValue/RetValue, but we // can't because of coherence rules. Instead, we generate them in the macro, and // add what we can (including Into instead of From) as subtraits. @@ -37,8 +47,8 @@ pub trait IsObjectRef: Sized + Clone + Into + + for<'a> AsArgValue<'a> + TryFrom - + for<'a> Into> + for<'a> TryFrom, Error = Error> + std::fmt::Debug { @@ -51,8 +61,8 @@ pub trait IsObjectRef: Self::from_ptr(None) } - fn into_arg_value<'a>(self) -> ArgValue<'a> { - self.into() + fn into_arg_value<'a>(&'a self) -> ArgValue<'a> { + self.as_arg_value() } fn from_arg_value<'a>(arg_value: ArgValue<'a>) -> Result { diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index a093cf5fe3ae..38dc99a88514 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -259,6 +259,10 @@ impl ObjectPtr { pub unsafe fn into_raw(self) -> *mut T { self.ptr.as_ptr() } + + pub unsafe fn as_ptr(&self) -> *mut T { + self.ptr.as_ptr() + } } impl std::ops::Deref for ObjectPtr { @@ -308,26 +312,26 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { } } -impl<'a, T: IsObject> From> for ArgValue<'a> { - fn from(object_ptr: ObjectPtr) -> ArgValue<'a> { +impl<'a, T: IsObject> From<&'a ObjectPtr> for ArgValue<'a> { + fn from(object_ptr: &'a ObjectPtr) -> ArgValue<'a> { debug_assert!(object_ptr.count() >= 1); - let object_ptr = object_ptr.upcast::(); + let object_ptr = object_ptr.clone().upcast::(); match T::TYPE_KEY { "runtime.NDArray" => { use crate::ndarray::NDArrayContainer; - // TODO(this is probably not optimal) - let raw_ptr = NDArrayContainer::leak(object_ptr.downcast().unwrap()) - as *mut NDArrayContainer as *mut std::ffi::c_void; + let dcast_ptr = object_ptr.downcast().unwrap(); + let raw_ptr = NDArrayContainer::as_mut_ptr(&dcast_ptr) + as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); ArgValue::NDArrayHandle(raw_ptr) } "runtime.Module" => { - let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void; + let raw_ptr = unsafe { object_ptr.as_ptr() } as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); ArgValue::ModuleHandle(raw_ptr) } _ => { - let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void; + let raw_ptr = unsafe { object_ptr.as_ptr() } as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); ArgValue::ObjectHandle(raw_ptr) } diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs index 7797d2cd23ff..0a053b7af539 100644 --- a/rust/tvm-rt/src/to_function.rs +++ b/rust/tvm-rt/src/to_function.rs @@ -32,6 +32,7 @@ use std::{ }; use super::{function::Result, Function}; +use crate::AsArgValue; use crate::errors::Error; pub use tvm_sys::{ffi, ArgValue, RetValue}; @@ -44,25 +45,39 @@ pub use tvm_sys::{ffi, ArgValue, RetValue}; /// conversion of inputs and outputs to this trait. /// /// And the implementation of it to `ToFunction`. -pub trait Typed { - fn args(i: Vec>) -> Result; +pub trait Typed<'a, I, O> { + fn args(i: Vec>) -> Result; fn ret(o: O) -> Result; } +trait AsArgValueErased where Self: for<'a> AsArgValue<'a> { + fn as_arg_value<'a>(&'a self) -> ArgValue<'a>; +} + +struct ArgList { + args: Vec> +} + +impl AsArgValueErased for T where T: for<'a> AsArgValue<'a> { + fn as_arg_value<'a>(&'a self) -> ArgValue<'a> { + AsArgValue::as_arg_value(self) + } +} + pub trait ToFunction: Sized { type Handle; fn into_raw(self) -> *mut Self::Handle; - fn call(handle: *mut Self::Handle, args: Vec>) -> Result + fn call<'a>(handle: *mut Self::Handle, args: Vec>) -> Result where - Self: Typed; + Self: for<'arg> Typed<'arg, I, O>; fn drop(handle: *mut Self::Handle); fn to_function(self) -> Function where - Self: Typed, + Self: for<'a> Typed<'a, I, O>, { let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; let resource_handle = self.into_raw(); @@ -87,7 +102,7 @@ pub trait ToFunction: Sized { resource_handle: *mut c_void, ) -> c_int where - Self: Typed, + Self: for <'a> Typed<'a, I, O>, { #![allow(unused_assignments, unused_unsafe)] let result = std::panic::catch_unwind(|| { @@ -165,45 +180,48 @@ pub trait ToFunction: Sized { } } -impl Typed>, RetValue> for fn(Vec>) -> Result { - fn args(args: Vec>) -> Result>> { - Ok(args) - } +// impl<'a> Typed<'a, Vec>, RetValue> for for<'arg> fn(Vec>) -> Result { +// fn args(args: Vec>) -> Result>> { +// Ok(args) +// } - fn ret(o: RetValue) -> Result { - Ok(o) - } -} +// fn ret(o: RetValue) -> Result { +// Ok(o) +// } +// } -impl ToFunction>, RetValue> - for fn(Vec>) -> Result -{ - type Handle = fn(Vec>) -> Result; +// impl ToFunction +// for fn(ArgList) -> Result +// { +// type Handle = for<'a> fn(Vec>) -> Result; - fn into_raw(self) -> *mut Self::Handle { - let ptr: Box = Box::new(self); - Box::into_raw(ptr) - } +// fn into_raw(self) -> *mut Self::Handle { +// let ptr: Box = Box::new(self); +// Box::into_raw(ptr) +// } - fn call(handle: *mut Self::Handle, args: Vec>) -> Result { - unsafe { (*handle)(args) } - } +// fn call<'a>(handle: *mut Self::Handle, args: Vec>) -> Result { +// unsafe { +// let func = (*handle); +// func(args) +// } +// } - fn drop(_: *mut Self::Handle) {} -} +// fn drop(_: *mut Self::Handle) {} +// } macro_rules! impl_typed_and_to_function { ($len:literal; $($t:ident),*) => { - impl Typed<($($t,)*), Out> for F + impl<'a, F, Out, $($t),*> Typed<'a, ($($t,)*), Out> for F where F: Fn($($t),*) -> Out, Out: TryInto, Error: From, - $( $t: TryFrom>, - Error: From<$t::Error>, )* + $( $t: TryFrom>, + Error: From<<$t as TryFrom>>::Error>, )* { #[allow(non_snake_case, unused_variables, unused_mut)] - fn args(args: Vec>) -> Result<($($t,)*)> { + fn args(args: Vec>) -> Result<($($t,)*)> { if args.len() != $len { return Err(Error::CallFailed(format!("{} expected {} arguments, got {}.\n", std::any::type_name::(), @@ -232,9 +250,9 @@ macro_rules! impl_typed_and_to_function { } #[allow(non_snake_case)] - fn call(handle: *mut Self::Handle, args: Vec>) -> Result + fn call<'a>(handle: *mut Self::Handle, args: Vec>) -> Result where - F: Typed<($($t,)*), Out> + F: for<'arg> Typed<'arg, ($($t,)*), Out> { let ($($t,)*) = F::args(args)?; let out = unsafe { (*handle)($($t),*) }; @@ -261,7 +279,7 @@ impl_typed_and_to_function!(6; A, B, C, D, E, G); mod tests { use super::*; - fn call(f: F, args: Vec>) -> Result + fn call<'a, F, I, O>(f: F, args: Vec>) -> Result where F: ToFunction, F: Typed, diff --git a/rust/tvm-sys/src/byte_array.rs b/rust/tvm-sys/src/byte_array.rs index 4b005abee7ef..7da6145797f7 100644 --- a/rust/tvm-sys/src/byte_array.rs +++ b/rust/tvm-sys/src/byte_array.rs @@ -20,7 +20,7 @@ use std::convert::TryFrom; use std::os::raw::c_char; use crate::errors::ValueDowncastError; -use crate::ffi::TVMByteArray; +use crate::ffi::{TVMByteArray, TVMByteArrayFree}; use crate::{ArgValue, RetValue}; /// A newtype wrapping a raw TVM byte-array. @@ -38,6 +38,11 @@ pub struct ByteArray { array: TVMByteArray, } +impl Drop for ByteArray { + fn drop(&mut self) { + } +} + impl ByteArray { /// Gets the underlying byte-array pub fn data(&self) -> &'static [u8] { @@ -59,6 +64,7 @@ impl ByteArray { } } + // Needs AsRef for Vec impl> From for ByteArray { fn from(arg: T) -> Self { @@ -78,20 +84,6 @@ impl<'a> From<&'a ByteArray> for ArgValue<'a> { } } -impl TryFrom> for ByteArray { - type Error = ValueDowncastError; - - fn try_from(val: ArgValue<'static>) -> Result { - match val { - ArgValue::Bytes(array) => Ok(ByteArray { array: *array }), - _ => Err(ValueDowncastError { - expected_type: "ByteArray", - actual_type: format!("{:?}", val), - }), - } - } -} - impl From for RetValue { fn from(val: ByteArray) -> RetValue { RetValue::Bytes(val.array) diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs index 6f43b786780a..e996b9ddf3b7 100644 --- a/rust/tvm-sys/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -284,9 +284,9 @@ impl<'a> From<&'a CStr> for ArgValue<'a> { } } -impl<'a> From for ArgValue<'a> { - fn from(s: CString) -> Self { - Self::String(s.into_raw()) +impl<'a> From<&'a CString> for ArgValue<'a> { + fn from(s: &'a CString) -> Self { + Self::String(s.as_ptr() as _) } } @@ -311,14 +311,14 @@ impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for &'v str { } /// Converts an unspecialized handle to a ArgValue. -impl From<*const T> for ArgValue<'static> { +impl<'a, T> From<*const T> for ArgValue<'a> { fn from(ptr: *const T) -> Self { Self::Handle(ptr as *mut c_void) } } /// Converts an unspecialized mutable handle to a ArgValue. -impl From<*mut T> for ArgValue<'static> { +impl<'a, T> From<*mut T> for ArgValue<'a> { fn from(ptr: *mut T) -> Self { Self::Handle(ptr as *mut c_void) } @@ -382,9 +382,9 @@ impl TryFrom for std::ffi::CString { // Implementations for bool. -impl<'a> From for ArgValue<'a> { - fn from(s: bool) -> Self { - (s as i64).into() +impl<'a> From<&bool> for ArgValue<'a> { + fn from(s: &bool) -> Self { + (*s as i64).into() } } diff --git a/rust/tvm/examples/resnet/src/main.rs b/rust/tvm/examples/resnet/src/main.rs index 96d74e2260a4..22933e0cc5af 100644 --- a/rust/tvm/examples/resnet/src/main.rs +++ b/rust/tvm/examples/resnet/src/main.rs @@ -83,11 +83,10 @@ fn main() -> anyhow::Result<()> { println!("param bytes: {}", params.len()); let mut output: Vec; + let mut graph_rt = GraphRt::create_from_parts(&graph, lib.clone(), dev)?; + graph_rt.load_params(¶ms)?; loop { - let mut graph_rt = GraphRt::create_from_parts(&graph, lib.clone(), dev)?; - - graph_rt.load_params(¶ms)?; graph_rt.set_input("data", input.clone())?; graph_rt.run()?; From 1e4edfe4a600d1575c1bb396c07353429319bf4a Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 13 Aug 2021 16:03:06 -0700 Subject: [PATCH 04/13] Finish handling tvm-rt issues due to ArgValue lifetime This patch further refactors the bindings to better handle the lifetime issues introduced by detecting the argument memory leak. --- rust/tvm-rt/src/function.rs | 29 ++++--- rust/tvm-rt/src/map.rs | 2 +- rust/tvm-rt/src/object/object_ptr.rs | 34 +++++--- rust/tvm-rt/src/to_function.rs | 120 +++++++++++++++------------ rust/tvm-sys/src/packed_func.rs | 2 +- 5 files changed, 103 insertions(+), 84 deletions(-) diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index 3a32c7ed9409..91c90de6686d 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -35,7 +35,7 @@ use std::{ use crate::errors::Error; -pub use super::to_function::{ToFunction, Typed}; +pub use super::to_function::{ToFunction, Typed, RawArgs}; pub use tvm_sys::{ffi, ArgValue, RetValue}; use crate::object::AsArgValue; @@ -264,7 +264,7 @@ impl<'a> TryFrom<&ArgValue<'a>> for Function { pub fn register>(f: F, name: S) -> Result<()> where F: ToFunction, - F: for<'a> Typed<'a, I, O>, + F: Typed, { register_override(f, name, false) } @@ -275,7 +275,7 @@ where pub fn register_override>(f: F, name: S, override_: bool) -> Result<()> where F: ToFunction, - F: for<'a> Typed<'a, I, O>, + F: Typed, { let func = f.to_function(); let name = name.into(); @@ -296,19 +296,18 @@ pub fn register_untyped>( name: S, override_: bool, ) -> Result<()> { - panic!("foo") // // TODO(@jroesch): can we unify all the code. - // let func = ToFunction::, RetValue>::to_function(f); - // let name = name.into(); - // // Not sure about this code - // let handle = func.handle(); - // let name = CString::new(name)?; - // check_call!(ffi::TVMFuncRegisterGlobal( - // name.into_raw(), - // handle, - // override_ as c_int - // )); - // Ok(()) + let func = ToFunction::::to_function(f); + let name = name.into(); + // Not sure about this code + let handle = func.handle(); + let name = CString::new(name)?; + check_call!(ffi::TVMFuncRegisterGlobal( + name.into_raw(), + handle, + override_ as c_int + )); + Ok(()) } #[cfg(test)] diff --git a/rust/tvm-rt/src/map.rs b/rust/tvm-rt/src/map.rs index dbfac6f205b3..5594a91dc0f0 100644 --- a/rust/tvm-rt/src/map.rs +++ b/rust/tvm-rt/src/map.rs @@ -268,7 +268,7 @@ mod test { let mut std_map: HashMap = HashMap::new(); std_map.insert("key1".into(), "value1".into()); std_map.insert("key2".into(), "value2".into()); - let tvm_map = Map::from_iter(std_map.clone().into_iter()); + let tvm_map = Map::from_iter(std_map.iter()); let back_map = tvm_map.into(); assert_eq!(std_map, back_map); } diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 38dc99a88514..59a4347bfe80 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -62,10 +62,12 @@ pub struct Object { /// "subtype". /// /// This function just converts the pointer to the correct type -/// and invokes the underlying typed delete function. +/// and reconstructs a Box which then is dropped to deallocate +/// the underlying allocation. unsafe extern "C" fn delete(object: *mut Object) { let typed_object: *mut T = object as *mut T; - T::typed_delete(typed_object); + let boxed: Box = Box::from_raw(typed_object); + drop(boxed); } fn derived_from(child_type_index: u32, parent_type_index: u32) -> bool { @@ -157,11 +159,6 @@ impl Object { /// to the subtype. pub unsafe trait IsObject: AsRef + std::fmt::Debug { const TYPE_KEY: &'static str; - - unsafe extern "C" fn typed_delete(object: *mut Self) { - let object = Box::from_raw(object); - drop(object) - } } /// A smart pointer for types which implement IsObject. @@ -332,6 +329,7 @@ impl<'a, T: IsObject> From<&'a ObjectPtr> for ArgValue<'a> { } _ => { let raw_ptr = unsafe { object_ptr.as_ptr() } as *mut std::ffi::c_void; + println!("raw_ptr {:?}", raw_ptr); assert!(!raw_ptr.is_null()); ArgValue::ObjectHandle(raw_ptr) } @@ -348,15 +346,24 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { match arg_value { ArgValue::ObjectHandle(handle) | ArgValue::ModuleHandle(handle) => { + println!("handle: {:?}", handle); let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; - debug_assert!(optr.count() >= 1); + // We are building an owned, ref-counted view into the underlying ArgValue, in order to be safe we must + // bump the reference count by one. + optr.inc_ref(); + assert!(optr.count() >= 1); optr.downcast() } ArgValue::NDArrayHandle(handle) => { let optr = NDArrayContainer::from_raw(handle as *mut DLTensor).ok_or(Error::Null)?; - debug_assert!(optr.count() >= 1); - optr.upcast::().downcast() + // We are building an owned, ref-counted view into the underlying ArgValue, in order to be safe we must + // bump the reference count by one. + assert!(optr.count() >= 1); + // TODO(@jroesch): figure out if there is a more optimal way to do this + let object = optr.upcast::(); + object.inc_ref(); + object.downcast() } _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), } @@ -444,11 +451,12 @@ mod tests { assert_eq!(ptr.count(), 1); let ptr_clone = ptr.clone(); assert_eq!(ptr.count(), 2); - let arg_value: ArgValue = ptr_clone.into(); + let arg_value: ArgValue = (&ptr_clone).into(); assert_eq!(ptr.count(), 2); let ptr2: ObjectPtr = arg_value.try_into()?; - assert_eq!(ptr2.count(), 2); + assert_eq!(ptr2.count(), 3); assert_eq!(ptr.count(), ptr2.count()); + drop(ptr_clone); assert_eq!(ptr.count(), 2); ensure!( ptr.type_index == ptr2.type_index, @@ -480,7 +488,7 @@ mod tests { assert_eq!(ptr.count(), 2); register(test_fn, "my_func2").unwrap(); let func = Function::get("my_func2").unwrap(); - let same = func.invoke(vec![ptr.into()]).unwrap(); + let same = func.invoke(vec![(&ptr).into()]).unwrap(); let same: ObjectPtr = same.try_into().unwrap(); // TODO(@jroesch): normalize RetValue ownership assert_eq!(same.count(), 2); drop(same); diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs index 0a053b7af539..f46467eb7d0e 100644 --- a/rust/tvm-rt/src/to_function.rs +++ b/rust/tvm-rt/src/to_function.rs @@ -32,7 +32,6 @@ use std::{ }; use super::{function::Result, Function}; -use crate::AsArgValue; use crate::errors::Error; pub use tvm_sys::{ffi, ArgValue, RetValue}; @@ -45,23 +44,17 @@ pub use tvm_sys::{ffi, ArgValue, RetValue}; /// conversion of inputs and outputs to this trait. /// /// And the implementation of it to `ToFunction`. -pub trait Typed<'a, I, O> { - fn args(i: Vec>) -> Result; - fn ret(o: O) -> Result; -} -trait AsArgValueErased where Self: for<'a> AsArgValue<'a> { - fn as_arg_value<'a>(&'a self) -> ArgValue<'a>; -} +type ArgList<'a> = Vec>; -struct ArgList { - args: Vec> +pub enum Args<'a, I> { + Typed(I), + Raw(ArgList<'a>) } -impl AsArgValueErased for T where T: for<'a> AsArgValue<'a> { - fn as_arg_value<'a>(&'a self) -> ArgValue<'a> { - AsArgValue::as_arg_value(self) - } +pub trait Typed { + fn args<'arg>(i: Vec>) -> Result>; + fn ret(o: O) -> Result; } pub trait ToFunction: Sized { @@ -71,13 +64,13 @@ pub trait ToFunction: Sized { fn call<'a>(handle: *mut Self::Handle, args: Vec>) -> Result where - Self: for<'arg> Typed<'arg, I, O>; + Self: Typed; fn drop(handle: *mut Self::Handle); fn to_function(self) -> Function where - Self: for<'a> Typed<'a, I, O>, + Self: Typed, { let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; let resource_handle = self.into_raw(); @@ -102,7 +95,7 @@ pub trait ToFunction: Sized { resource_handle: *mut c_void, ) -> c_int where - Self: for <'a> Typed<'a, I, O>, + Self: Typed, { #![allow(unused_assignments, unused_unsafe)] let result = std::panic::catch_unwind(|| { @@ -180,56 +173,69 @@ pub trait ToFunction: Sized { } } -// impl<'a> Typed<'a, Vec>, RetValue> for for<'arg> fn(Vec>) -> Result { -// fn args(args: Vec>) -> Result>> { -// Ok(args) -// } +pub struct RawArgs; -// fn ret(o: RetValue) -> Result { -// Ok(o) -// } -// } +impl Typed for for <'a> fn(Vec>) -> Result { + fn args<'arg>(args: Vec>) -> Result> { + Ok(Args::Raw(args)) + } -// impl ToFunction -// for fn(ArgList) -> Result -// { -// type Handle = for<'a> fn(Vec>) -> Result; + fn ret(o: RetValue) -> Result { + Ok(o) + } +} -// fn into_raw(self) -> *mut Self::Handle { -// let ptr: Box = Box::new(self); -// Box::into_raw(ptr) -// } +impl ToFunction + for for <'arg> fn(Vec>) -> Result +{ + type Handle = for <'arg> fn(Vec>) -> Result; -// fn call<'a>(handle: *mut Self::Handle, args: Vec>) -> Result { -// unsafe { -// let func = (*handle); -// func(args) -// } -// } + fn into_raw(self) -> *mut Self::Handle { + let ptr: Box = Box::new(self); + Box::into_raw(ptr) + } -// fn drop(_: *mut Self::Handle) {} -// } + fn call<'arg>(handle: *mut Self::Handle, args: Vec>) -> Result { + unsafe { + let func = *handle; + func(args) + } + } + + fn drop(_: *mut Self::Handle) {} +} + +/// A helper trait which correctly captures the complex conversion and lifetime semantics needed +/// to coerce an ordinary Rust value into `ArgValue`. +pub trait TryFromArgValue: TryFrom { + fn from_arg_value(f: F) -> std::result::Result; +} + +impl<'a, T> TryFromArgValue> for T where Self: TryFrom>, Error: From<>>::Error> { + fn from_arg_value(f: ArgValue<'a>) -> std::result::Result { + Ok(TryFrom::try_from(f)?) + } +} macro_rules! impl_typed_and_to_function { ($len:literal; $($t:ident),*) => { - impl<'a, F, Out, $($t),*> Typed<'a, ($($t,)*), Out> for F + impl Typed<($($t,)*), Out> for Fun where - F: Fn($($t),*) -> Out, + Fun: Fn($($t),*) -> Out, Out: TryInto, Error: From, - $( $t: TryFrom>, - Error: From<<$t as TryFrom>>::Error>, )* + $( for<'a> $t: TryFromArgValue>, )* { #[allow(non_snake_case, unused_variables, unused_mut)] - fn args(args: Vec>) -> Result<($($t,)*)> { + fn args<'arg>(args: Vec>) -> Result> { if args.len() != $len { return Err(Error::CallFailed(format!("{} expected {} arguments, got {}.\n", std::any::type_name::(), $len, args.len()))) } let mut args = args.into_iter(); - $(let $t = args.next().unwrap().try_into()?;)* - Ok(($($t,)*)) + $(let $t = TryFromArgValue::from_arg_value(args.next().unwrap())?;)* + Ok(Args::Typed(($($t,)*))) } fn ret(out: Out) -> Result { @@ -238,9 +244,9 @@ macro_rules! impl_typed_and_to_function { } - impl ToFunction<($($t,)*), Out> for F + impl ToFunction<($($t,)*), Out> for Fun where - F: Fn($($t,)*) -> Out + 'static + Fun: Fn($($t,)*) -> Out + 'static { type Handle = Box Out + 'static>; @@ -252,11 +258,15 @@ macro_rules! impl_typed_and_to_function { #[allow(non_snake_case)] fn call<'a>(handle: *mut Self::Handle, args: Vec>) -> Result where - F: for<'arg> Typed<'arg, ($($t,)*), Out> + Fun: Typed<($($t,)*), Out> { - let ($($t,)*) = F::args(args)?; + let ($($t,)*) = match Fun::args(args)? { + Args::Raw(_) => panic!("impossible case"), + Args::Typed(typed) => typed, + }; + let out = unsafe { (*handle)($($t),*) }; - F::ret(out) + Fun::ret(out) } fn drop(ptr: *mut Self::Handle) { @@ -273,7 +283,9 @@ impl_typed_and_to_function!(2; A, B); impl_typed_and_to_function!(3; A, B, C); impl_typed_and_to_function!(4; A, B, C, D); impl_typed_and_to_function!(5; A, B, C, D, E); -impl_typed_and_to_function!(6; A, B, C, D, E, G); +impl_typed_and_to_function!(6; A, B, C, D, E, F); +impl_typed_and_to_function!(7; A, B, C, D, E, F, G); +impl_typed_and_to_function!(8; A, B, C, D, E, F, G, H); #[cfg(test)] mod tests { diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs index e996b9ddf3b7..a74cbe318e2d 100644 --- a/rust/tvm-sys/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -224,7 +224,7 @@ macro_rules! impl_pod_value { } } - impl<'a, 'v> From<&'a $type> for ArgValue<'v> { + impl<'a> From<&'a $type> for ArgValue<'a> { fn from(val: &'a $type) -> Self { Self::$variant(*val as $inner_ty) } From 109545a2e4aab48a856db8fc1110be100215238c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 13 Aug 2021 17:33:52 -0700 Subject: [PATCH 05/13] WIP memory leak --- rust/tvm-rt/src/function.rs | 4 ++ rust/tvm-rt/src/object/object_ptr.rs | 65 ++++++++++++++++++++-------- rust/tvm-rt/src/to_function.rs | 8 ++-- rust/tvm-sys/src/packed_func.rs | 2 +- 4 files changed, 56 insertions(+), 23 deletions(-) diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index 91c90de6686d..178e1e63e45b 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -131,6 +131,10 @@ impl Function { ) }; + if ret_type_code == crate::ffi::TVMArgTypeCode_kTVMObjectRValueRefArg as _ { + panic!() + } + if ret_code != 0 { let raw_error = crate::get_last_error(); let error = match Error::from_raw_tvm(raw_error) { diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 59a4347bfe80..d6afd543abd5 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -20,11 +20,12 @@ use std::convert::TryFrom; use std::ffi::CString; use std::fmt; +use std::os::raw::c_char; use std::ptr::NonNull; use std::sync::atomic::AtomicI32; use tvm_macros::Object; -use tvm_sys::ffi::{self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeKey2Index}; +use tvm_sys::ffi::{self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeKey2Index, TVMObjectTypeIndex2Key}; use tvm_sys::{ArgValue, RetValue}; use crate::errors::Error; @@ -100,6 +101,18 @@ impl Object { } } + fn get_type_key(&self) -> String { + let mut cstring: *mut c_char = std::ptr::null_mut(); + unsafe { + if TVMObjectTypeIndex2Key(self.type_index, &mut cstring as *mut _) != 0 { + panic!("{}", crate::get_last_error()); + } + return CString::from_raw(cstring) + .into_string() + .expect("type keys should be valid utf-8"); + } + } + fn get_type_index() -> u32 { let type_key = T::TYPE_KEY; let cstring = CString::new(type_key).expect("type key must not contain null characters"); @@ -249,7 +262,8 @@ impl ObjectPtr { if is_derived { Ok(unsafe { self.cast() }) } else { - Err(Error::downcast("TODOget_type_key".into(), U::TYPE_KEY)) + let type_key = self.as_ref().get_type_key(); + Err(Error::downcast(type_key.into(), U::TYPE_KEY)) } } @@ -317,8 +331,7 @@ impl<'a, T: IsObject> From<&'a ObjectPtr> for ArgValue<'a> { "runtime.NDArray" => { use crate::ndarray::NDArrayContainer; let dcast_ptr = object_ptr.downcast().unwrap(); - let raw_ptr = NDArrayContainer::as_mut_ptr(&dcast_ptr) - as *mut std::ffi::c_void; + let raw_ptr = NDArrayContainer::as_mut_ptr(&dcast_ptr) as *mut std::ffi::c_void; assert!(!raw_ptr.is_null()); ArgValue::NDArrayHandle(raw_ptr) } @@ -329,7 +342,6 @@ impl<'a, T: IsObject> From<&'a ObjectPtr> for ArgValue<'a> { } _ => { let raw_ptr = unsafe { object_ptr.as_ptr() } as *mut std::ffi::c_void; - println!("raw_ptr {:?}", raw_ptr); assert!(!raw_ptr.is_null()); ArgValue::ObjectHandle(raw_ptr) } @@ -346,11 +358,10 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { match arg_value { ArgValue::ObjectHandle(handle) | ArgValue::ModuleHandle(handle) => { - println!("handle: {:?}", handle); let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; // We are building an owned, ref-counted view into the underlying ArgValue, in order to be safe we must // bump the reference count by one. - optr.inc_ref(); + // optr.inc_ref(); assert!(optr.count() >= 1); optr.downcast() } @@ -362,7 +373,7 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { assert!(optr.count() >= 1); // TODO(@jroesch): figure out if there is a more optimal way to do this let object = optr.upcast::(); - object.inc_ref(); + // object.inc_ref(); object.downcast() } _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), @@ -473,25 +484,43 @@ mod tests { } fn test_fn(o: ObjectPtr) -> ObjectPtr { - // The call machinery adds at least 1 extra count while inside the call. - assert_eq!(o.count(), 3); + assert_eq!(o.count(), 2); return o; } + fn test_fn_raw<'a>(mut args: crate::to_function::ArgList<'a>) -> crate::function::Result { + let v: ArgValue = args.remove(0); + let v2: ArgValue = args.remove(0); + // assert_eq!(o.count(), 2); + let o: ObjectPtr = v.try_into().unwrap(); + let o2: ObjectPtr = v2.try_into().unwrap(); + assert_eq!(o.count(), 3); + assert_eq!(o2.count(), 3); + // return o; + Ok(o.into()) + } + + #[test] fn test_ref_count_boundary3() { use super::*; - use crate::function::{register, Function}; + use crate::function::{register, register_untyped, Function}; + use crate::to_function::ToFunction; let ptr = ObjectPtr::new(Object::base::()); assert_eq!(ptr.count(), 1); - let stay = ptr.clone(); - assert_eq!(ptr.count(), 2); - register(test_fn, "my_func2").unwrap(); - let func = Function::get("my_func2").unwrap(); - let same = func.invoke(vec![(&ptr).into()]).unwrap(); + register_untyped(test_fn_raw, "foo", true); + let raw_func = Function::get("foo").unwrap(); + let same = raw_func.invoke(vec![(&ptr).into(), (&ptr).into()]).unwrap(); let same: ObjectPtr = same.try_into().unwrap(); - // TODO(@jroesch): normalize RetValue ownership assert_eq!(same.count(), 2); drop(same); - assert_eq!(stay.count(), 3); + drop(raw_func); + // let func = test_fn.to_function(); + // let same = func.invoke(vec![(&ptr).into()]).unwrap(); + // drop(same); + // let same = func.invoke(vec![(&ptr).into()]).unwrap(); + // let same: ObjectPtr = same.try_into().unwrap(); + // drop(same); + // drop(func); + assert_eq!(ptr.count(), 1); } } diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs index f46467eb7d0e..57dac2af1525 100644 --- a/rust/tvm-rt/src/to_function.rs +++ b/rust/tvm-rt/src/to_function.rs @@ -45,7 +45,7 @@ pub use tvm_sys::{ffi, ArgValue, RetValue}; /// /// And the implementation of it to `ToFunction`. -type ArgList<'a> = Vec>; +pub type ArgList<'a> = Vec>; pub enum Args<'a, I> { Typed(I), @@ -78,7 +78,7 @@ pub trait ToFunction: Sized { check_call!(ffi::TVMFuncCreateFromCFunc( Some(Self::tvm_callback), resource_handle as *mut _, - None, // Some(Self::tvm_finalizer), + Some(Self::tvm_finalizer), &mut fhandle as *mut ffi::TVMFunctionHandle, )); @@ -125,7 +125,6 @@ pub trait ToFunction: Sized { local_args.push(arg_value); } - // Ref-count be 2. let rv = match Self::call(resource_handle, local_args) { Ok(v) => v, Err(msg) => { @@ -265,7 +264,8 @@ macro_rules! impl_typed_and_to_function { Args::Typed(typed) => typed, }; - let out = unsafe { (*handle)($($t),*) }; + let fn_ptr = unsafe { &*handle }; + let out = fn_ptr($($t),*); Fun::ret(out) } diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs index a74cbe318e2d..c8a7e8cbade5 100644 --- a/rust/tvm-sys/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -68,7 +68,7 @@ macro_rules! TVMPODValue { $(,)? } => { $(#[$m])+ - #[derive(Clone, Debug)] + #[derive(Debug)] pub enum $name $(<$a>)? { Int(i64), UInt(i64), From 4da8326626b55159d7ee570fae3b88249bc92bb6 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 13 Aug 2021 18:25:50 -0700 Subject: [PATCH 06/13] There is issue using TVMCb function which is breaking refcount --- rust/tvm-rt/src/function.rs | 12 ++++++++- rust/tvm-rt/src/object/object_ptr.rs | 34 +++++++++++++++++++---- rust/tvm-rt/src/to_function.rs | 40 ++++++++++++++++++++-------- rust/tvm-sys/src/packed_func.rs | 2 +- 4 files changed, 70 insertions(+), 18 deletions(-) diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index 178e1e63e45b..45dfa6c296f2 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -115,7 +115,12 @@ impl Function { pub fn invoke<'a>(&self, arg_buf: Vec>) -> Result { let num_args = arg_buf.len(); let (mut values, mut type_codes): (Vec, Vec) = - arg_buf.into_iter().map(|arg| arg.to_tvm_value()).unzip(); + arg_buf.clone().into_iter().map(|arg| arg.to_tvm_value()).unzip(); + + for arg in arg_buf.clone() { + let oref: crate::object::ObjectPtr = arg.clone().try_into().unwrap(); + println!("oref: {:?}", oref.count()); + } let mut ret_val = ffi::TVMValue { v_int64: 0 }; let mut ret_type_code = 0i32; @@ -144,6 +149,11 @@ impl Function { return Err(error); } + for arg in &arg_buf { + let oref: crate::object::ObjectPtr = arg.clone().try_into().unwrap(); + println!("oref: {:?}", oref.count()); + } + let rv = RetValue::from_tvm_value(ret_val, ret_type_code as u32); Ok(rv) diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index d6afd543abd5..cf5548f65064 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -361,7 +361,7 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; // We are building an owned, ref-counted view into the underlying ArgValue, in order to be safe we must // bump the reference count by one. - // optr.inc_ref(); + optr.inc_ref(); assert!(optr.count() >= 1); optr.downcast() } @@ -373,7 +373,7 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { assert!(optr.count() >= 1); // TODO(@jroesch): figure out if there is a more optimal way to do this let object = optr.upcast::(); - // object.inc_ref(); + object.inc_ref(); object.downcast() } _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), @@ -493,13 +493,36 @@ mod tests { let v2: ArgValue = args.remove(0); // assert_eq!(o.count(), 2); let o: ObjectPtr = v.try_into().unwrap(); + assert_eq!(o.count(), 2); let o2: ObjectPtr = v2.try_into().unwrap(); - assert_eq!(o.count(), 3); assert_eq!(o2.count(), 3); - // return o; + drop(o2); + assert_eq!(o.count(), 2); Ok(o.into()) } + #[test] + fn test_ref_count_raw_fn() { + use super::*; + use crate::function::{register, register_untyped, Function}; + use crate::to_function::ToFunction; + let ptr = ObjectPtr::new(Object::base::()); + // Call the function without the wrapping for TVM. + assert_eq!(ptr.count(), 1); + let same = test_fn_raw(vec![(&ptr).into(), (&ptr).into()]).unwrap(); + let output: ObjectPtr = same.try_into().unwrap(); + assert_eq!(output.count(), 2); + drop(output); + assert_eq!(ptr.count(), 1); + + register_untyped(test_fn_raw, "test_fn_raw", true).unwrap(); + let raw_func = Function::get("test_fn_raw").unwrap(); + let output = raw_func.invoke(vec![(&ptr).into(), (&ptr).into()]).unwrap(); + let output: ObjectPtr = output.try_into().unwrap(); + assert_eq!(output.count(), 2); + drop(output); + assert_eq!(ptr.count(), 1); + } #[test] fn test_ref_count_boundary3() { @@ -510,7 +533,8 @@ mod tests { assert_eq!(ptr.count(), 1); register_untyped(test_fn_raw, "foo", true); let raw_func = Function::get("foo").unwrap(); - let same = raw_func.invoke(vec![(&ptr).into(), (&ptr).into()]).unwrap(); + // let same = raw_func.invoke(vec![(&ptr).into(), (&ptr).into()]).unwrap(); + let same = test_fn_raw(vec![(&ptr).into(), (&ptr).into()]).unwrap(); let same: ObjectPtr = same.try_into().unwrap(); drop(same); drop(raw_func); diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs index 57dac2af1525..e3513eec7259 100644 --- a/rust/tvm-rt/src/to_function.rs +++ b/rust/tvm-rt/src/to_function.rs @@ -110,21 +110,26 @@ pub trait ToFunction: Sized { for i in 0..len { value = args_list[i]; tcode = type_codes_list[i]; - if tcode == ffi::TVMArgTypeCode_kTVMObjectHandle as c_int - || tcode == ffi::TVMArgTypeCode_kTVMObjectRValueRefArg as c_int - || tcode == ffi::TVMArgTypeCode_kTVMPackedFuncHandle as c_int - || tcode == ffi::TVMArgTypeCode_kTVMModuleHandle as c_int - || tcode == ffi::TVMArgTypeCode_kTVMNDArrayHandle as c_int - { - check_call!(ffi::TVMCbArgToReturn( - &mut value as *mut _, - &mut tcode as *mut _ - )); - } + // if tcode == ffi::TVMArgTypeCode_kTVMObjectHandle as c_int + // || tcode == ffi::TVMArgTypeCode_kTVMObjectRValueRefArg as c_int + // || tcode == ffi::TVMArgTypeCode_kTVMPackedFuncHandle as c_int + // || tcode == ffi::TVMArgTypeCode_kTVMModuleHandle as c_int + // || tcode == ffi::TVMArgTypeCode_kTVMNDArrayHandle as c_int + // { + // check_call!(ffi::TVMCbArgToReturn( + // &mut value as *mut _, + // &mut tcode as *mut _ + // )); + // } let arg_value = ArgValue::from_tvm_value(value, tcode as u32); local_args.push(arg_value); } + for arg in local_args.clone() { + let oref: crate::object::ObjectPtr = arg.clone().try_into().unwrap(); + println!("right before call oref: {:?}", oref.count()); + } + let rv = match Self::call(resource_handle, local_args) { Ok(v) => v, Err(msg) => { @@ -132,6 +137,9 @@ pub trait ToFunction: Sized { } }; + let oref: crate::object::ObjectPtr = rv.clone().try_into().unwrap(); + // println!("ret value oref: {:?}", oref.count()); + let (mut ret_val, ret_tcode) = rv.to_tvm_value(); let mut ret_type_code = ret_tcode as c_int; @@ -176,6 +184,11 @@ pub struct RawArgs; impl Typed for for <'a> fn(Vec>) -> Result { fn args<'arg>(args: Vec>) -> Result> { + for arg in args.clone() { + let oref: crate::object::ObjectPtr = arg.clone().try_into().unwrap(); + println!("args oref: {:?}", oref.count()); + } + Ok(Args::Raw(args)) } @@ -195,6 +208,11 @@ impl ToFunction } fn call<'arg>(handle: *mut Self::Handle, args: Vec>) -> Result { + for arg in args.clone() { + let oref: crate::object::ObjectPtr = arg.clone().try_into().unwrap(); + println!("call oref: {:?}", oref.count()); + } + unsafe { let func = *handle; func(args) diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs index c8a7e8cbade5..a74cbe318e2d 100644 --- a/rust/tvm-sys/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -68,7 +68,7 @@ macro_rules! TVMPODValue { $(,)? } => { $(#[$m])+ - #[derive(Debug)] + #[derive(Clone, Debug)] pub enum $name $(<$a>)? { Int(i64), UInt(i64), From 9c308e786fb87f12a9279befa8eb3d6ffa602454 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 16 Aug 2021 20:51:50 -0700 Subject: [PATCH 07/13] Fix fallout from the lifetime refactor --- rust/tvm-rt/src/function.rs | 25 ++++------- rust/tvm-rt/src/map.rs | 37 ++++++++++++++++ rust/tvm-rt/src/object/mod.rs | 5 ++- rust/tvm-rt/src/object/object_ptr.rs | 54 ++++++++++++----------- rust/tvm-rt/src/to_function.rs | 41 +++++++---------- rust/tvm-sys/src/byte_array.rs | 4 +- rust/tvm/src/compiler/graph_rt.rs | 10 ++--- rust/tvm/src/ir/module.rs | 14 +++--- rust/tvm/tests/callback/src/bin/array.rs | 2 +- rust/tvm/tests/callback/src/bin/error.rs | 2 +- rust/tvm/tests/callback/src/bin/float.rs | 2 +- rust/tvm/tests/callback/src/bin/int.rs | 2 +- rust/tvm/tests/callback/src/bin/string.rs | 2 +- src/runtime/object.cc | 2 +- 14 files changed, 114 insertions(+), 88 deletions(-) diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index 45dfa6c296f2..450aba578377 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -35,9 +35,9 @@ use std::{ use crate::errors::Error; -pub use super::to_function::{ToFunction, Typed, RawArgs}; -pub use tvm_sys::{ffi, ArgValue, RetValue}; +pub use super::to_function::{RawArgs, ToFunction, Typed}; use crate::object::AsArgValue; +pub use tvm_sys::{ffi, ArgValue, RetValue}; pub type Result = std::result::Result; @@ -114,13 +114,11 @@ impl Function { /// Calls the function that created from `Builder`. pub fn invoke<'a>(&self, arg_buf: Vec>) -> Result { let num_args = arg_buf.len(); - let (mut values, mut type_codes): (Vec, Vec) = - arg_buf.clone().into_iter().map(|arg| arg.to_tvm_value()).unzip(); - - for arg in arg_buf.clone() { - let oref: crate::object::ObjectPtr = arg.clone().try_into().unwrap(); - println!("oref: {:?}", oref.count()); - } + let (mut values, mut type_codes): (Vec, Vec) = arg_buf + .clone() + .into_iter() + .map(|arg| arg.to_tvm_value()) + .unzip(); let mut ret_val = ffi::TVMValue { v_int64: 0 }; let mut ret_type_code = 0i32; @@ -149,11 +147,6 @@ impl Function { return Err(error); } - for arg in &arg_buf { - let oref: crate::object::ObjectPtr = arg.clone().try_into().unwrap(); - println!("oref: {:?}", oref.count()); - } - let rv = RetValue::from_tvm_value(ret_val, ret_type_code as u32); Ok(rv) @@ -211,8 +204,8 @@ impl TryFrom for Function { } } -impl<'a> From for ArgValue<'a> { - fn from(func: Function) -> ArgValue<'a> { +impl<'a> From<&Function> for ArgValue<'a> { + fn from(func: &Function) -> ArgValue<'a> { if func.handle().is_null() { ArgValue::Null } else { diff --git a/rust/tvm-rt/src/map.rs b/rust/tvm-rt/src/map.rs index 5594a91dc0f0..8ae5d06a28bf 100644 --- a/rust/tvm-rt/src/map.rs +++ b/rust/tvm-rt/src/map.rs @@ -121,6 +121,43 @@ where } } +// pub struct Iter<'a, K, V> { +// // NB: due to FFI this isn't as lazy as one might like +// key_and_values: Array, +// next_key: i64, +// _data: PhantomData<(&'a K, &'a V)>, +// } + +// impl<'a, K, V> Iterator for Iter<'a, K, V> +// where +// K: IsObjectRef, +// V: IsObjectRef, +// { +// type Item = (&'a K, &'a V); + +// #[inline] +// fn next(&mut self) -> Option<(&'a K, &'a V)> { +// if self.next_key < self.key_and_values.len() { +// let key = self +// .key_and_values +// .get(self.next_key as isize) +// .expect("this should always succeed"); +// let value = self +// .key_and_values +// .get((self.next_key as isize) + 1) +// .expect("this should always succeed"); +// self.next_key += 2; +// Some((key.downcast().unwrap(), value.downcast().unwrap())) +// } else { +// None +// } +// } + +// #[inline] +// fn size_hint(&self) -> (usize, Option) { +// ((self.key_and_values.len() / 2) as usize, None) +// } +// } pub struct IntoIter { // NB: due to FFI this isn't as lazy as one might like key_and_values: Array, diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs index 075ef46f35e2..f5832fcb3ab8 100644 --- a/rust/tvm-rt/src/object/mod.rs +++ b/rust/tvm-rt/src/object/mod.rs @@ -33,7 +33,10 @@ pub trait AsArgValue<'a> { fn as_arg_value(&'a self) -> ArgValue<'a>; } -impl<'a, T: 'static> AsArgValue<'a> for T where &'a T: Into> { +impl<'a, T: 'static> AsArgValue<'a> for T +where + &'a T: Into>, +{ fn as_arg_value(&'a self) -> ArgValue<'a> { self.into() } diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index cf5548f65064..2f4ff175e999 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -25,7 +25,9 @@ use std::ptr::NonNull; use std::sync::atomic::AtomicI32; use tvm_macros::Object; -use tvm_sys::ffi::{self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeKey2Index, TVMObjectTypeIndex2Key}; +use tvm_sys::ffi::{ + self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeIndex2Key, TVMObjectTypeKey2Index, +}; use tvm_sys::{ArgValue, RetValue}; use crate::errors::Error; @@ -483,12 +485,9 @@ mod tests { Ok(()) } - fn test_fn(o: ObjectPtr) -> ObjectPtr { - assert_eq!(o.count(), 2); - return o; - } - - fn test_fn_raw<'a>(mut args: crate::to_function::ArgList<'a>) -> crate::function::Result { + fn test_fn_raw<'a>( + mut args: crate::to_function::ArgList<'a>, + ) -> crate::function::Result { let v: ArgValue = args.remove(0); let v2: ArgValue = args.remove(0); // assert_eq!(o.count(), 2); @@ -504,8 +503,7 @@ mod tests { #[test] fn test_ref_count_raw_fn() { use super::*; - use crate::function::{register, register_untyped, Function}; - use crate::to_function::ToFunction; + use crate::function::{register_untyped, Function}; let ptr = ObjectPtr::new(Object::base::()); // Call the function without the wrapping for TVM. assert_eq!(ptr.count(), 1); @@ -524,27 +522,33 @@ mod tests { assert_eq!(ptr.count(), 1); } + fn test_fn_typed(o: ObjectPtr, o2: ObjectPtr) -> ObjectPtr { + assert_eq!(o.count(), 3); + assert_eq!(o2.count(), 3); + drop(o2); + assert_eq!(o.count(), 2); + return o; + } + #[test] - fn test_ref_count_boundary3() { + fn test_ref_count_typed() { use super::*; - use crate::function::{register, register_untyped, Function}; - use crate::to_function::ToFunction; + use crate::function::{register, Function}; let ptr = ObjectPtr::new(Object::base::()); + // Call the function without the wrapping for TVM. assert_eq!(ptr.count(), 1); - register_untyped(test_fn_raw, "foo", true); - let raw_func = Function::get("foo").unwrap(); - // let same = raw_func.invoke(vec![(&ptr).into(), (&ptr).into()]).unwrap(); let same = test_fn_raw(vec![(&ptr).into(), (&ptr).into()]).unwrap(); - let same: ObjectPtr = same.try_into().unwrap(); - drop(same); - drop(raw_func); - // let func = test_fn.to_function(); - // let same = func.invoke(vec![(&ptr).into()]).unwrap(); - // drop(same); - // let same = func.invoke(vec![(&ptr).into()]).unwrap(); - // let same: ObjectPtr = same.try_into().unwrap(); - // drop(same); - // drop(func); + let output: ObjectPtr = same.try_into().unwrap(); + assert_eq!(output.count(), 2); + drop(output); + assert_eq!(ptr.count(), 1); + + register(test_fn_typed, "test_fn_typed").unwrap(); + let raw_func = Function::get("test_fn_typed").unwrap(); + let output = raw_func.invoke(vec![(&ptr).into(), (&ptr).into()]).unwrap(); + let output: ObjectPtr = output.try_into().unwrap(); + assert_eq!(output.count(), 2); + drop(output); assert_eq!(ptr.count(), 1); } } diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs index e3513eec7259..523eb70f1598 100644 --- a/rust/tvm-rt/src/to_function.rs +++ b/rust/tvm-rt/src/to_function.rs @@ -49,7 +49,7 @@ pub type ArgList<'a> = Vec>; pub enum Args<'a, I> { Typed(I), - Raw(ArgList<'a>) + Raw(ArgList<'a>), } pub trait Typed { @@ -110,6 +110,13 @@ pub trait ToFunction: Sized { for i in 0..len { value = args_list[i]; tcode = type_codes_list[i]; + // TODO(@jroesch): I believe it is sound to disable this specialized move rule. + // + // This is used in C++ to deal with moving an RValue or reference to a return value + // directly so you can skip copying. + // + // I believe this is not needed as the move directly occurs into the Rust function. + // if tcode == ffi::TVMArgTypeCode_kTVMObjectHandle as c_int // || tcode == ffi::TVMArgTypeCode_kTVMObjectRValueRefArg as c_int // || tcode == ffi::TVMArgTypeCode_kTVMPackedFuncHandle as c_int @@ -125,11 +132,6 @@ pub trait ToFunction: Sized { local_args.push(arg_value); } - for arg in local_args.clone() { - let oref: crate::object::ObjectPtr = arg.clone().try_into().unwrap(); - println!("right before call oref: {:?}", oref.count()); - } - let rv = match Self::call(resource_handle, local_args) { Ok(v) => v, Err(msg) => { @@ -137,9 +139,6 @@ pub trait ToFunction: Sized { } }; - let oref: crate::object::ObjectPtr = rv.clone().try_into().unwrap(); - // println!("ret value oref: {:?}", oref.count()); - let (mut ret_val, ret_tcode) = rv.to_tvm_value(); let mut ret_type_code = ret_tcode as c_int; @@ -182,13 +181,8 @@ pub trait ToFunction: Sized { pub struct RawArgs; -impl Typed for for <'a> fn(Vec>) -> Result { +impl Typed for for<'a> fn(Vec>) -> Result { fn args<'arg>(args: Vec>) -> Result> { - for arg in args.clone() { - let oref: crate::object::ObjectPtr = arg.clone().try_into().unwrap(); - println!("args oref: {:?}", oref.count()); - } - Ok(Args::Raw(args)) } @@ -197,10 +191,8 @@ impl Typed for for <'a> fn(Vec>) -> Result - for for <'arg> fn(Vec>) -> Result -{ - type Handle = for <'arg> fn(Vec>) -> Result; +impl ToFunction for for<'arg> fn(Vec>) -> Result { + type Handle = for<'arg> fn(Vec>) -> Result; fn into_raw(self) -> *mut Self::Handle { let ptr: Box = Box::new(self); @@ -208,11 +200,6 @@ impl ToFunction } fn call<'arg>(handle: *mut Self::Handle, args: Vec>) -> Result { - for arg in args.clone() { - let oref: crate::object::ObjectPtr = arg.clone().try_into().unwrap(); - println!("call oref: {:?}", oref.count()); - } - unsafe { let func = *handle; func(args) @@ -228,7 +215,11 @@ pub trait TryFromArgValue: TryFrom { fn from_arg_value(f: F) -> std::result::Result; } -impl<'a, T> TryFromArgValue> for T where Self: TryFrom>, Error: From<>>::Error> { +impl<'a, T> TryFromArgValue> for T +where + Self: TryFrom>, + Error: From<>>::Error>, +{ fn from_arg_value(f: ArgValue<'a>) -> std::result::Result { Ok(TryFrom::try_from(f)?) } diff --git a/rust/tvm-sys/src/byte_array.rs b/rust/tvm-sys/src/byte_array.rs index 7da6145797f7..f4341aebe762 100644 --- a/rust/tvm-sys/src/byte_array.rs +++ b/rust/tvm-sys/src/byte_array.rs @@ -39,8 +39,7 @@ pub struct ByteArray { } impl Drop for ByteArray { - fn drop(&mut self) { - } + fn drop(&mut self) {} } impl ByteArray { @@ -64,7 +63,6 @@ impl ByteArray { } } - // Needs AsRef for Vec impl> From for ByteArray { fn from(arg: T) -> Self { diff --git a/rust/tvm/src/compiler/graph_rt.rs b/rust/tvm/src/compiler/graph_rt.rs index 6b5873398cab..8313e47bea20 100644 --- a/rust/tvm/src/compiler/graph_rt.rs +++ b/rust/tvm/src/compiler/graph_rt.rs @@ -51,11 +51,11 @@ fn _compile_module( ) -> Result { // The RAW API is Fn(IRModule, String, String, Map, String); let module = TVM_BUILD.invoke(vec![ - module.into(), - target.into(), - target_host.into(), - params.into(), - module_name.into(), + (&module).into(), + (&target).into(), + (&target_host).into(), + (¶ms).into(), + (&module_name).into(), ])?; let module: RtModule = module.try_into().unwrap(); Ok(module) diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index 513a906f6db4..ea257af1ebc0 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -99,10 +99,10 @@ external! { // Note: we don't expose update here as update is going to be removed. impl IRModule { - pub fn new(funcs: F, types: T) -> Result + pub fn new<'a, F, T>(funcs: F, types: T) -> Result where - F: IntoIterator, - T: IntoIterator, + F: IntoIterator, + T: IntoIterator, { module_new(Map::from_iter(funcs), Map::from_iter(types)) } @@ -110,7 +110,7 @@ impl IRModule { pub fn empty() -> Result { let funcs = HashMap::::new(); let types = HashMap::::new(); - IRModule::new(funcs, types) + IRModule::new(funcs.iter(), types.iter()) } pub fn parse(file_name: N, source: S) -> Result @@ -206,10 +206,10 @@ impl IRModule { Self::from_expr_with_items(expr, HashMap::new(), HashMap::new()) } - pub fn from_expr_with_items(expr: E, funcs: F, types: T) -> Result + pub fn from_expr_with_items<'a, E, F, T>(expr: E, funcs: F, types: T) -> Result where - F: IntoIterator, - T: IntoIterator, + F: IntoIterator, + T: IntoIterator, E: IsObjectRef, E::Object: AsRef<::Object>, { diff --git a/rust/tvm/tests/callback/src/bin/array.rs b/rust/tvm/tests/callback/src/bin/array.rs index 81ee426d3967..8deae30c076d 100644 --- a/rust/tvm/tests/callback/src/bin/array.rs +++ b/rust/tvm/tests/callback/src/bin/array.rs @@ -35,7 +35,7 @@ use tvm::{ }; fn main() { - fn sum(args: Vec>) -> Result { + fn sum<'a>(args: Vec>) -> Result { let mut ret = 0.0; for arg in args { let arg: NDArray = arg.try_into()?; diff --git a/rust/tvm/tests/callback/src/bin/error.rs b/rust/tvm/tests/callback/src/bin/error.rs index 37027af0ca37..f8886a55c3a2 100644 --- a/rust/tvm/tests/callback/src/bin/error.rs +++ b/rust/tvm/tests/callback/src/bin/error.rs @@ -26,7 +26,7 @@ use tvm::{ }; fn main() { - fn error(_args: Vec>) -> Result { + fn error<'a>(_args: Vec>) -> Result { Err(errors::NDArrayError::DataTypeMismatch { expected: DataType::int(64, 1), actual: DataType::float(64, 1), diff --git a/rust/tvm/tests/callback/src/bin/float.rs b/rust/tvm/tests/callback/src/bin/float.rs index 6fd4f868dc79..d575f47c87cd 100644 --- a/rust/tvm/tests/callback/src/bin/float.rs +++ b/rust/tvm/tests/callback/src/bin/float.rs @@ -27,7 +27,7 @@ use tvm::{ }; fn main() { - fn sum(args: Vec>) -> Result { + fn sum<'a>(args: Vec>) -> Result { let mut ret = 0.0; for arg in args.into_iter() { let val: f64 = arg.try_into()?; diff --git a/rust/tvm/tests/callback/src/bin/int.rs b/rust/tvm/tests/callback/src/bin/int.rs index cdea2e1044c4..fc2e40d8de4d 100644 --- a/rust/tvm/tests/callback/src/bin/int.rs +++ b/rust/tvm/tests/callback/src/bin/int.rs @@ -25,7 +25,7 @@ use tvm::{ }; fn main() { - fn sum(args: Vec>) -> Result { + fn sum<'a>(args: Vec>) -> Result { let mut ret = 0i64; for arg in args.iter() { let val: i64 = arg.try_into()?; diff --git a/rust/tvm/tests/callback/src/bin/string.rs b/rust/tvm/tests/callback/src/bin/string.rs index dbe65ba4c631..4f3d67e95d64 100644 --- a/rust/tvm/tests/callback/src/bin/string.rs +++ b/rust/tvm/tests/callback/src/bin/string.rs @@ -26,7 +26,7 @@ use tvm::{ // FIXME fn main() { - fn concat_str(args: Vec>) -> Result { + fn concat_str<'a>(args: Vec>) -> Result { let mut ret = "".to_string(); for arg in args.iter() { let val: &str = arg.try_into()?; diff --git a/src/runtime/object.cc b/src/runtime/object.cc index f89d805da136..3cd5df613f4a 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -265,7 +265,7 @@ int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) { int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key) { API_BEGIN(); - auto key = tvm::runtime::Object::TypeIndex2Key(tindex); + auto key = tvm::runtime::Object::TypeIndex2Key(tindex); *out_type_key = static_cast(malloc(key.size() + 1)); strncpy(*out_type_key, key.c_str(), key.size()); API_END(); From dfb0add1fe97424f1bf228c8089f4cec0b3bc814 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 16 Aug 2021 21:42:01 -0700 Subject: [PATCH 08/13] Another tweak --- rust/tvm-rt/src/object/object_ptr.rs | 65 ++++++++++++++-------------- rust/tvm-rt/src/to_function.rs | 5 +++ 2 files changed, 38 insertions(+), 32 deletions(-) diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 2f4ff175e999..58a4040025b9 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -229,7 +229,7 @@ impl ObjectPtr { // ABI compatible atomics is funky/hard. self.as_ref() .ref_count - .load(std::sync::atomic::Ordering::Relaxed) + .load(std::sync::atomic::Ordering::Acquire) } /// This method avoid running the destructor on self once it's dropped, so we don't accidentally release the memory @@ -361,9 +361,9 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { match arg_value { ArgValue::ObjectHandle(handle) | ArgValue::ModuleHandle(handle) => { let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; + optr.inc_ref(); // We are building an owned, ref-counted view into the underlying ArgValue, in order to be safe we must // bump the reference count by one. - optr.inc_ref(); assert!(optr.count() >= 1); optr.downcast() } @@ -458,32 +458,32 @@ mod tests { Ok(()) } - #[test] - fn roundtrip_argvalue() -> Result<()> { - let ptr = ObjectPtr::new(Object::base::()); - assert_eq!(ptr.count(), 1); - let ptr_clone = ptr.clone(); - assert_eq!(ptr.count(), 2); - let arg_value: ArgValue = (&ptr_clone).into(); - assert_eq!(ptr.count(), 2); - let ptr2: ObjectPtr = arg_value.try_into()?; - assert_eq!(ptr2.count(), 3); - assert_eq!(ptr.count(), ptr2.count()); - drop(ptr_clone); - assert_eq!(ptr.count(), 2); - ensure!( - ptr.type_index == ptr2.type_index, - "type indices do not match" - ); - ensure!( - ptr.fdeleter == ptr2.fdeleter, - "objects have different deleters" - ); - // After dropping the second pointer we should only see only refcount. - drop(ptr2); - assert_eq!(ptr.count(), 1); - Ok(()) - } + // #[test] + // fn roundtrip_argvalue() -> Result<()> { + // let ptr = ObjectPtr::new(Object::base::()); + // assert_eq!(ptr.count(), 1); + // let ptr_clone = ptr.clone(); + // assert_eq!(ptr.count(), 2); + // let arg_value: ArgValue = (&ptr_clone).into(); + // assert_eq!(ptr.count(), 2); + // let ptr2: ObjectPtr = arg_value.try_into()?; + // assert_eq!(ptr2.count(), 3); + // assert_eq!(ptr.count(), ptr2.count()); + // drop(ptr_clone); + // assert_eq!(ptr.count(), 2); + // ensure!( + // ptr.type_index == ptr2.type_index, + // "type indices do not match" + // ); + // ensure!( + // ptr.fdeleter == ptr2.fdeleter, + // "objects have different deleters" + // ); + // // After dropping the second pointer we should only see only refcount. + // drop(ptr2); + // assert_eq!(ptr.count(), 1); + // Ok(()) + // } fn test_fn_raw<'a>( mut args: crate::to_function::ArgList<'a>, @@ -537,15 +537,16 @@ mod tests { let ptr = ObjectPtr::new(Object::base::()); // Call the function without the wrapping for TVM. assert_eq!(ptr.count(), 1); - let same = test_fn_raw(vec![(&ptr).into(), (&ptr).into()]).unwrap(); - let output: ObjectPtr = same.try_into().unwrap(); + let output = test_fn_typed(ptr.clone(), ptr.clone()); assert_eq!(output.count(), 2); drop(output); assert_eq!(ptr.count(), 1); register(test_fn_typed, "test_fn_typed").unwrap(); - let raw_func = Function::get("test_fn_typed").unwrap(); - let output = raw_func.invoke(vec![(&ptr).into(), (&ptr).into()]).unwrap(); + let typed_func = Function::get("test_fn_typed").unwrap(); + let output = typed_func + .invoke(vec![(&ptr).into(), (&ptr).into()]) + .unwrap(); let output: ObjectPtr = output.try_into().unwrap(); assert_eq!(output.count(), 2); drop(output); diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs index 523eb70f1598..dbf8e6f23559 100644 --- a/rust/tvm-rt/src/to_function.rs +++ b/rust/tvm-rt/src/to_function.rs @@ -139,6 +139,11 @@ pub trait ToFunction: Sized { } }; + match rv.clone().try_into() as Result> { + Err(e) => {} + Ok(v) => drop(v), + }; + let (mut ret_val, ret_tcode) = rv.to_tvm_value(); let mut ret_type_code = ret_tcode as c_int; From 4912a52463165b105d4cd7a6b7a5fdc0835456e4 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 19 Aug 2021 10:01:49 -0700 Subject: [PATCH 09/13] Follow up work from the memory leak, attempt to clean up ByteArray --- rust/tvm-rt/src/graph_rt.rs | 1 + rust/tvm-rt/src/object/object_ptr.rs | 54 +++++++++---------- rust/tvm-sys/src/byte_array.rs | 79 +++++++++++++++++++--------- rust/tvm/examples/resnet/src/main.rs | 46 ++++++++++------ 4 files changed, 113 insertions(+), 67 deletions(-) diff --git a/rust/tvm-rt/src/graph_rt.rs b/rust/tvm-rt/src/graph_rt.rs index 5ac9710424e0..53f3210aa742 100644 --- a/rust/tvm-rt/src/graph_rt.rs +++ b/rust/tvm-rt/src/graph_rt.rs @@ -55,6 +55,7 @@ impl GraphRt { // NOTE you must pass the device id in as i32 because that's what TVM expects (dev.device_id as i32).into(), ]); + let graph_executor_module: Module = runtime_create_fn_ret?.try_into()?; Ok(Self { module: graph_executor_module, diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 58a4040025b9..09d6068f1a88 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -229,7 +229,7 @@ impl ObjectPtr { // ABI compatible atomics is funky/hard. self.as_ref() .ref_count - .load(std::sync::atomic::Ordering::Acquire) + .load(std::sync::atomic::Ordering::Relaxed) } /// This method avoid running the destructor on self once it's dropped, so we don't accidentally release the memory @@ -458,32 +458,32 @@ mod tests { Ok(()) } - // #[test] - // fn roundtrip_argvalue() -> Result<()> { - // let ptr = ObjectPtr::new(Object::base::()); - // assert_eq!(ptr.count(), 1); - // let ptr_clone = ptr.clone(); - // assert_eq!(ptr.count(), 2); - // let arg_value: ArgValue = (&ptr_clone).into(); - // assert_eq!(ptr.count(), 2); - // let ptr2: ObjectPtr = arg_value.try_into()?; - // assert_eq!(ptr2.count(), 3); - // assert_eq!(ptr.count(), ptr2.count()); - // drop(ptr_clone); - // assert_eq!(ptr.count(), 2); - // ensure!( - // ptr.type_index == ptr2.type_index, - // "type indices do not match" - // ); - // ensure!( - // ptr.fdeleter == ptr2.fdeleter, - // "objects have different deleters" - // ); - // // After dropping the second pointer we should only see only refcount. - // drop(ptr2); - // assert_eq!(ptr.count(), 1); - // Ok(()) - // } + #[test] + fn roundtrip_argvalue() -> Result<()> { + let ptr = ObjectPtr::new(Object::base::()); + assert_eq!(ptr.count(), 1); + let ptr_clone = ptr.clone(); + assert_eq!(ptr.count(), 2); + let arg_value: ArgValue = (&ptr_clone).into(); + assert_eq!(ptr.count(), 2); + let ptr2: ObjectPtr = arg_value.try_into()?; + assert_eq!(ptr2.count(), 3); + assert_eq!(ptr.count(), ptr2.count()); + drop(ptr_clone); + assert_eq!(ptr.count(), 2); + ensure!( + ptr.type_index == ptr2.type_index, + "type indices do not match" + ); + ensure!( + ptr.fdeleter == ptr2.fdeleter, + "objects have different deleters" + ); + // After dropping the second pointer we should only see only refcount. + drop(ptr2); + assert_eq!(ptr.count(), 1); + Ok(()) + } fn test_fn_raw<'a>( mut args: crate::to_function::ArgList<'a>, diff --git a/rust/tvm-sys/src/byte_array.rs b/rust/tvm-sys/src/byte_array.rs index f4341aebe762..dfc6ab5edf1f 100644 --- a/rust/tvm-sys/src/byte_array.rs +++ b/rust/tvm-sys/src/byte_array.rs @@ -17,7 +17,6 @@ * under the License. */ use std::convert::TryFrom; -use std::os::raw::c_char; use crate::errors::ValueDowncastError; use crate::ffi::{TVMByteArray, TVMByteArrayFree}; @@ -33,24 +32,45 @@ use crate::{ArgValue, RetValue}; /// assert_eq!(barr.len(), v.len()); /// assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]); /// ``` -pub struct ByteArray { - /// The raw FFI ByteArray. - array: TVMByteArray, +pub enum ByteArray { + Rust(TVMByteArray), + External(TVMByteArray), } impl Drop for ByteArray { - fn drop(&mut self) {} + fn drop(&mut self) { + match self { + ByteArray::Rust(bytes) => { + let ptr = bytes.data; + let len = bytes.size as _; + let cap = bytes.size as _; + let data: Vec = unsafe { Vec::from_raw_parts(ptr as _, len, cap) }; + drop(data); + }, + ByteArray::External(byte_array) => { + unsafe { if TVMByteArrayFree(byte_array as _) != 0 { + panic!("error"); + } } + } + } + } } impl ByteArray { /// Gets the underlying byte-array - pub fn data(&self) -> &'static [u8] { - unsafe { std::slice::from_raw_parts(self.array.data as *const u8, self.array.size as _) } + pub fn data(&self) -> &[u8] { + match self { + ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => { + unsafe { std::slice::from_raw_parts(byte_array.data as *const u8, byte_array.size as _) } + } + } } /// Gets the length of the underlying byte-array pub fn len(&self) -> usize { - self.array.size as _ + match self { + ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => byte_array.size as _, + } } /// Converts the underlying byte-array to `Vec` @@ -63,36 +83,47 @@ impl ByteArray { } } -// Needs AsRef for Vec -impl> From for ByteArray { +impl>> From for ByteArray { fn from(arg: T) -> Self { - let arg = arg.as_ref(); - ByteArray { - array: TVMByteArray { - data: arg.as_ptr() as *const c_char, - size: arg.len() as _, - }, - } + + let mut incoming_bytes: Vec = arg.into(); + let mut bytes = Vec::with_capacity(incoming_bytes.len()); + bytes.append(&mut incoming_bytes); + + let mut bytes = std::mem::ManuallyDrop::new(bytes); + let ptr = bytes.as_mut_ptr(); + assert_eq!(bytes.len(), bytes.capacity()); + ByteArray::Rust(TVMByteArray { data: ptr as _, size: bytes.len() as _ }) } } impl<'a> From<&'a ByteArray> for ArgValue<'a> { fn from(val: &'a ByteArray) -> ArgValue<'a> { - ArgValue::Bytes(&val.array) + match val { + ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => { + ArgValue::Bytes(byte_array) + } + } } } -impl From for RetValue { - fn from(val: ByteArray) -> RetValue { - RetValue::Bytes(val.array) - } -} +// TODO(@jroesch): This requires a little more work, going to land narratives +// impl From for RetValue { +// fn from(val: ByteArray) -> RetValue { +// // match val { +// // ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => { +// // ArgValue::Bytes(byte_array) +// // } +// // } +// panic!("need to audit the lifetimes of this code"); +// } +// } impl TryFrom for ByteArray { type Error = ValueDowncastError; fn try_from(val: RetValue) -> Result { match val { - RetValue::Bytes(array) => Ok(ByteArray { array }), + RetValue::Bytes(array) => Ok(ByteArray::External(array)), _ => Err(ValueDowncastError { expected_type: "ByteArray", actual_type: format!("{:?}", val), diff --git a/rust/tvm/examples/resnet/src/main.rs b/rust/tvm/examples/resnet/src/main.rs index 22933e0cc5af..c22d55f2e4da 100644 --- a/rust/tvm/examples/resnet/src/main.rs +++ b/rust/tvm/examples/resnet/src/main.rs @@ -82,22 +82,36 @@ fn main() -> anyhow::Result<()> { let params: Vec = fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params"))?; println!("param bytes: {}", params.len()); - let mut output: Vec; - let mut graph_rt = GraphRt::create_from_parts(&graph, lib.clone(), dev)?; - graph_rt.load_params(¶ms)?; - - loop { - graph_rt.set_input("data", input.clone())?; - graph_rt.run()?; - - // prepare to get the output - let output_shape = &[1, 1000]; - let output_nd = NDArray::empty(output_shape, Device::cpu(0), DataType::float(32, 1)); - graph_rt.get_output_into(0, output_nd.clone())?; - - // flatten the output as Vec - output = output_nd.to_vec::()?; - } + // If you want an easy way to test a memory leak simply replace the program below with: + // let mut output: Vec; + + // loop { + // let mut graph_rt = GraphRt::create_from_parts(&graph, lib.clone(), dev)?; + // graph_rt.load_params(params.clone())?; + // graph_rt.set_input("data", input.clone())?; + // graph_rt.run()?; + + // // prepare to get the output + // let output_shape = &[1, 1000]; + // let output_nd = NDArray::empty(output_shape, Device::cpu(0), DataType::float(32, 1)); + // graph_rt.get_output_into(0, output_nd.clone())?; + + // // flatten the output as Vec + // output = output_nd.to_vec::()?; + // } + + let mut graph_rt = GraphRt::create_from_parts(&graph, lib, dev)?; + graph_rt.load_params(params)?; + graph_rt.set_input("data", input)?; + graph_rt.run()?; + + // prepare to get the output + let output_shape = &[1, 1000]; + let output_nd = NDArray::empty(output_shape, Device::cpu(0), DataType::float(32, 1)); + graph_rt.get_output_into(0, output_nd.clone())?; + + // flatten the output as Vec + let output: Vec = output_nd.to_vec::()?; // find the maximum entry in the output and its index let (argmax, max_prob) = output From 53dac41a9f7ffa5798ca211e19166688887d1b0d Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 19 Aug 2021 10:43:10 -0700 Subject: [PATCH 10/13] Add some todos for future work --- rust/tvm-rt/src/lib.rs | 21 ++++++++-------- rust/tvm-rt/src/to_function.rs | 3 ++- rust/tvm-sys/src/byte_array.rs | 44 ++++++++++++++++++---------------- 3 files changed, 36 insertions(+), 32 deletions(-) diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs index 824dc63f0b50..3b7d066e7b78 100644 --- a/rust/tvm-rt/src/lib.rs +++ b/rust/tvm-rt/src/lib.rs @@ -130,16 +130,17 @@ mod tests { ); } - #[test] - fn bytearray() { - let w = vec![1u8, 2, 3, 4, 5]; - let v = ByteArray::from(w.as_slice()); - let tvm: ByteArray = RetValue::from(v).try_into().unwrap(); - assert_eq!( - tvm.data(), - w.iter().copied().collect::>().as_slice() - ); - } + // todo(@jroesch): #8800 Follow up with ByteArray RetValue ownership. + // #[test] + // fn bytearray() { + // let w = vec![1u8, 2, 3, 4, 5]; + // let v = ByteArray::from(w.as_slice()); + // let tvm: ByteArray = RetValue::from(v).try_into().unwrap(); + // assert_eq!( + // tvm.data(), + // w.iter().copied().collect::>().as_slice() + // ); + // } #[test] fn ty() { diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs index dbf8e6f23559..67fbfc996af0 100644 --- a/rust/tvm-rt/src/to_function.rs +++ b/rust/tvm-rt/src/to_function.rs @@ -139,8 +139,9 @@ pub trait ToFunction: Sized { } }; + // TODO(@jroesch): clean up the handling of the is dec_ref match rv.clone().try_into() as Result> { - Err(e) => {} + Err(_) => {} Ok(v) => drop(v), }; diff --git a/rust/tvm-sys/src/byte_array.rs b/rust/tvm-sys/src/byte_array.rs index dfc6ab5edf1f..2903a81d9c36 100644 --- a/rust/tvm-sys/src/byte_array.rs +++ b/rust/tvm-sys/src/byte_array.rs @@ -46,12 +46,12 @@ impl Drop for ByteArray { let cap = bytes.size as _; let data: Vec = unsafe { Vec::from_raw_parts(ptr as _, len, cap) }; drop(data); - }, - ByteArray::External(byte_array) => { - unsafe { if TVMByteArrayFree(byte_array as _) != 0 { - panic!("error"); - } } } + ByteArray::External(byte_array) => unsafe { + if TVMByteArrayFree(byte_array as _) != 0 { + panic!("error"); + } + }, } } } @@ -60,16 +60,16 @@ impl ByteArray { /// Gets the underlying byte-array pub fn data(&self) -> &[u8] { match self { - ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => { - unsafe { std::slice::from_raw_parts(byte_array.data as *const u8, byte_array.size as _) } - } + ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => unsafe { + std::slice::from_raw_parts(byte_array.data as *const u8, byte_array.size as _) + }, } } /// Gets the length of the underlying byte-array pub fn len(&self) -> usize { match self { - ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => byte_array.size as _, + ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => byte_array.size as _, } } @@ -85,7 +85,6 @@ impl ByteArray { impl>> From for ByteArray { fn from(arg: T) -> Self { - let mut incoming_bytes: Vec = arg.into(); let mut bytes = Vec::with_capacity(incoming_bytes.len()); bytes.append(&mut incoming_bytes); @@ -93,29 +92,32 @@ impl>> From for ByteArray { let mut bytes = std::mem::ManuallyDrop::new(bytes); let ptr = bytes.as_mut_ptr(); assert_eq!(bytes.len(), bytes.capacity()); - ByteArray::Rust(TVMByteArray { data: ptr as _, size: bytes.len() as _ }) + ByteArray::Rust(TVMByteArray { + data: ptr as _, + size: bytes.len() as _, + }) } } impl<'a> From<&'a ByteArray> for ArgValue<'a> { fn from(val: &'a ByteArray) -> ArgValue<'a> { match val { - ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => { + ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => { ArgValue::Bytes(byte_array) } } } } -// TODO(@jroesch): This requires a little more work, going to land narratives +// todo(@jroesch): #8800 Follow up with ByteArray RetValue ownership. // impl From for RetValue { // fn from(val: ByteArray) -> RetValue { -// // match val { -// // ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => { -// // ArgValue::Bytes(byte_array) -// // } -// // } -// panic!("need to audit the lifetimes of this code"); +// match val { +// ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => { +// // TODO(@jroesch): This requires a little more work, going to land narratives +// RetValue::Bytes(byte_array) +// } +// } // } // } @@ -139,11 +141,11 @@ mod tests { #[test] fn convert() { let v = vec![1u8, 2, 3]; - let barr = ByteArray::from(&v); + let barr = ByteArray::from(v.to_vec()); assert_eq!(barr.len(), v.len()); assert_eq!(barr.to_vec(), vec![1u8, 2, 3]); let v = b"hello"; - let barr = ByteArray::from(&v); + let barr = ByteArray::from(v.to_vec()); assert_eq!(barr.len(), v.len()); assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]); } From 1a3e1e502091cff587831962cb25a677608bebd8 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 19 Aug 2021 17:08:45 -0700 Subject: [PATCH 11/13] Fix doc string --- include/tvm/runtime/c_runtime_api.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 1039590b34a8..8454b04443a1 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -521,9 +521,9 @@ TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex); TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); /*! - * \brief Convert type key to type index. - * \param type_key The key of the type. - * \param out_tindex the corresponding type index. + * \brief Convert type index to type key. + * \param tindex The type index. + * \param out_type_key The output type key. * \return 0 when success, nonzero when failure happens */ TVM_DLL int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key); From 5a90e2f9fd83a825226eb7f088a9e17da91c6004 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 20 Aug 2021 15:23:34 -0700 Subject: [PATCH 12/13] Clean up the changes --- rust/tvm-rt/src/function.rs | 11 +++-------- rust/tvm-rt/src/map.rs | 37 ------------------------------------- 2 files changed, 3 insertions(+), 45 deletions(-) diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index 450aba578377..97fc497ad629 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -115,7 +115,6 @@ impl Function { pub fn invoke<'a>(&self, arg_buf: Vec>) -> Result { let num_args = arg_buf.len(); let (mut values, mut type_codes): (Vec, Vec) = arg_buf - .clone() .into_iter() .map(|arg| arg.to_tvm_value()) .unzip(); @@ -134,10 +133,6 @@ impl Function { ) }; - if ret_type_code == crate::ffi::TVMArgTypeCode_kTVMObjectRValueRefArg as _ { - panic!() - } - if ret_code != 0 { let raw_error = crate::get_last_error(); let error = match Error::from_raw_tvm(raw_error) { @@ -204,8 +199,8 @@ impl TryFrom for Function { } } -impl<'a> From<&Function> for ArgValue<'a> { - fn from(func: &Function) -> ArgValue<'a> { +impl<'a> From<&'a Function> for ArgValue<'a> { + fn from(func: &'a Function) -> ArgValue<'a> { if func.handle().is_null() { ArgValue::Null } else { @@ -303,7 +298,7 @@ pub fn register_untyped>( name: S, override_: bool, ) -> Result<()> { - // // TODO(@jroesch): can we unify all the code. + //TODO(@jroesch): can we unify the untpyed and typed registration functions. let func = ToFunction::::to_function(f); let name = name.into(); // Not sure about this code diff --git a/rust/tvm-rt/src/map.rs b/rust/tvm-rt/src/map.rs index 8ae5d06a28bf..5594a91dc0f0 100644 --- a/rust/tvm-rt/src/map.rs +++ b/rust/tvm-rt/src/map.rs @@ -121,43 +121,6 @@ where } } -// pub struct Iter<'a, K, V> { -// // NB: due to FFI this isn't as lazy as one might like -// key_and_values: Array, -// next_key: i64, -// _data: PhantomData<(&'a K, &'a V)>, -// } - -// impl<'a, K, V> Iterator for Iter<'a, K, V> -// where -// K: IsObjectRef, -// V: IsObjectRef, -// { -// type Item = (&'a K, &'a V); - -// #[inline] -// fn next(&mut self) -> Option<(&'a K, &'a V)> { -// if self.next_key < self.key_and_values.len() { -// let key = self -// .key_and_values -// .get(self.next_key as isize) -// .expect("this should always succeed"); -// let value = self -// .key_and_values -// .get((self.next_key as isize) + 1) -// .expect("this should always succeed"); -// self.next_key += 2; -// Some((key.downcast().unwrap(), value.downcast().unwrap())) -// } else { -// None -// } -// } - -// #[inline] -// fn size_hint(&self) -> (usize, Option) { -// ((self.key_and_values.len() / 2) as usize, None) -// } -// } pub struct IntoIter { // NB: due to FFI this isn't as lazy as one might like key_and_values: Array, From b6ecf60a1f72d42f47626a8317f18c3e4966a798 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 20 Aug 2021 17:31:52 -0700 Subject: [PATCH 13/13] Format --- rust/tvm-rt/src/function.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index 97fc497ad629..62474e6650d4 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -114,10 +114,8 @@ impl Function { /// Calls the function that created from `Builder`. pub fn invoke<'a>(&self, arg_buf: Vec>) -> Result { let num_args = arg_buf.len(); - let (mut values, mut type_codes): (Vec, Vec) = arg_buf - .into_iter() - .map(|arg| arg.to_tvm_value()) - .unzip(); + let (mut values, mut type_codes): (Vec, Vec) = + arg_buf.into_iter().map(|arg| arg.to_tvm_value()).unzip(); let mut ret_val = ffi::TVMValue { v_int64: 0 }; let mut ret_type_code = 0i32;