diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index aec4a8ad44de..5db665cc7a48 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -26,6 +26,7 @@ //! See the tests and examples repository for more examples. use std::convert::{TryFrom, TryInto}; +use std::sync::Arc; use std::{ ffi::CString, os::raw::{c_char, c_int}, @@ -39,36 +40,43 @@ pub use tvm_sys::{ffi, ArgValue, RetValue}; pub type Result = std::result::Result; -/// Wrapper around TVM function handle which includes `is_global` -/// indicating whether the function is global or not, and `is_cloned` showing -/// not to drop a cloned function from Rust side. -/// The value of these fields can be accessed through their respective methods. #[derive(Debug, Hash)] -pub struct Function { - pub(crate) handle: ffi::TVMFunctionHandle, - // whether the registered function is global or not. - is_global: bool, - from_rust: bool, +struct FunctionPtr { + handle: ffi::TVMFunctionHandle, } -unsafe impl Send for Function {} -unsafe impl Sync for Function {} +// NB(@jroesch): I think this is ok, need to double check, +// if not we should mutex the pointer or move to Rc. +unsafe impl Send for FunctionPtr {} +unsafe impl Sync for FunctionPtr {} + +impl FunctionPtr { + fn from_raw(handle: ffi::TVMFunctionHandle) -> Self { + FunctionPtr { handle } + } +} + +impl Drop for FunctionPtr { + fn drop(&mut self) { + check_call!(ffi::TVMFuncFree(self.handle)); + } +} + +/// An owned thread-safe version of `tvm::PackedFunc` for consumption in Rust. +#[derive(Debug, Hash)] +pub struct Function { + inner: Arc, +} impl Function { - pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self { + pub(crate) fn from_raw(handle: ffi::TVMFunctionHandle) -> Self { Function { - handle, - is_global: false, - from_rust: false, + inner: Arc::new(FunctionPtr::from_raw(handle)), } } pub unsafe fn null() -> Self { - Function { - handle: std::ptr::null_mut(), - is_global: false, - from_rust: false, - } + Function::from_raw(std::ptr::null_mut()) } /// For a given function, it returns a function by name. @@ -84,11 +92,7 @@ impl Function { if handle.is_null() { None } else { - Some(Function { - handle, - is_global: true, - from_rust: false, - }) + Some(Function::from_raw(handle)) } } @@ -103,12 +107,7 @@ impl Function { /// Returns the underlying TVM function handle. pub fn handle(&self) -> ffi::TVMFunctionHandle { - self.handle - } - - /// Returns `true` if the underlying TVM function is global and `false` otherwise. - pub fn is_global(&self) -> bool { - self.is_global + self.inner.handle } /// Calls the function that created from `Builder`. @@ -122,7 +121,7 @@ impl Function { let ret_code = unsafe { ffi::TVMFuncCall( - self.handle, + self.handle(), values.as_mut_ptr() as *mut ffi::TVMValue, type_codes.as_mut_ptr() as *mut c_int, num_args as c_int, @@ -171,25 +170,15 @@ impl_to_fn!(T1, T2, T3, T4, T5, T6,); impl Clone for Function { fn clone(&self) -> Function { - Self { - handle: self.handle, - is_global: self.is_global, - from_rust: true, + Function { + inner: self.inner.clone(), } } } -// impl Drop for Function { -// fn drop(&mut self) { -// if !self.is_global && !self.is_cloned { -// check_call!(ffi::TVMFuncFree(self.handle)); -// } -// } -// } - impl From for RetValue { fn from(func: Function) -> RetValue { - RetValue::FuncHandle(func.handle) + RetValue::FuncHandle(func.handle()) } } @@ -198,7 +187,7 @@ impl TryFrom for Function { fn try_from(ret_value: RetValue) -> Result { match ret_value { - RetValue::FuncHandle(handle) => Ok(Function::new(handle)), + RetValue::FuncHandle(handle) => Ok(Function::from_raw(handle)), _ => Err(Error::downcast( format!("{:?}", ret_value), "FunctionHandle", @@ -209,10 +198,10 @@ impl TryFrom for Function { impl<'a> From for ArgValue<'a> { fn from(func: Function) -> ArgValue<'a> { - if func.handle.is_null() { + if func.handle().is_null() { ArgValue::Null } else { - ArgValue::FuncHandle(func.handle) + ArgValue::FuncHandle(func.handle()) } } } @@ -222,7 +211,7 @@ impl<'a> TryFrom> for Function { fn try_from(arg_value: ArgValue<'a>) -> Result { match arg_value { - ArgValue::FuncHandle(handle) => Ok(Function::new(handle)), + ArgValue::FuncHandle(handle) => Ok(Function::from_raw(handle)), _ => Err(Error::downcast( format!("{:?}", arg_value), "FunctionHandle", @@ -236,7 +225,7 @@ impl<'a> TryFrom<&ArgValue<'a>> for Function { fn try_from(arg_value: &ArgValue<'a>) -> Result { match arg_value { - ArgValue::FuncHandle(handle) => Ok(Function::new(*handle)), + ArgValue::FuncHandle(handle) => Ok(Function::from_raw(*handle)), _ => Err(Error::downcast( format!("{:?}", arg_value), "FunctionHandle", diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs index 343f0dce8f98..8d59c2a035a9 100644 --- a/rust/tvm-rt/src/module.rs +++ b/rust/tvm-rt/src/module.rs @@ -82,7 +82,7 @@ impl Module { return Err(errors::Error::NullHandle(name.into_string()?.to_string())); } - Ok(Function::new(fhandle)) + Ok(Function::from_raw(fhandle)) } /// Imports a dependent module such as `.ptx` for cuda gpu. diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index 0e2d2830615f..08dcfe33f28f 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -61,7 +61,7 @@ use num_traits::Num; use crate::errors::NDArrayError; -use crate::object::{Object, ObjectPtr}; +use crate::object::{Object, ObjectPtr, ObjectRef}; /// See the [`module-level documentation`](../ndarray/index.html) for more details. #[repr(C)] @@ -73,7 +73,7 @@ pub struct NDArrayContainer { // Container Base dl_tensor: DLTensor, manager_ctx: *mut c_void, - // TOOD: shape? + shape: ObjectRef, } impl NDArrayContainer { diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 64fd6a2218aa..a093cf5fe3ae 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -148,18 +148,6 @@ impl Object { } } -// impl fmt::Debug for Object { -// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { -// let index = -// format!("{} // key: {}", self.type_index, "the_key"); - -// f.debug_struct("Object") -// .field("type_index", &index) -// // TODO(@jroesch: do we expose other fields?) -// .finish() -// } -// } - /// An unsafe trait which should be implemented for an object /// subtype. /// diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs index c5ede7d224ce..7797d2cd23ff 100644 --- a/rust/tvm-rt/src/to_function.rs +++ b/rust/tvm-rt/src/to_function.rs @@ -74,7 +74,7 @@ pub trait ToFunction: Sized { &mut fhandle as *mut ffi::TVMFunctionHandle, )); - Function::new(fhandle) + Function::from_raw(fhandle) } /// The callback function which is wrapped converted by TVM diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs index 930ee59b7bf9..7793f9f6962e 100644 --- a/rust/tvm-sys/build.rs +++ b/rust/tvm-sys/build.rs @@ -19,7 +19,10 @@ extern crate bindgen; -use std::{path::{Path, PathBuf}, str::FromStr}; +use std::{ + path::{Path, PathBuf}, + str::FromStr, +}; use anyhow::{Context, Result}; use tvm_build::{BuildConfig, CMakeSetting}; @@ -195,7 +198,10 @@ fn find_using_tvm_build() -> Result { if cfg!(feature = "use-vitis-ai") { build_config.settings.use_vitis_ai = Some(true); } - if cfg!(any(feature = "static-linking", feature = "build-static-runtime")) { + if cfg!(any( + feature = "static-linking", + feature = "build-static-runtime" + )) { build_config.settings.build_static_runtime = Some(true); } diff --git a/rust/tvm/tests/basics/src/main.rs b/rust/tvm/tests/basics/src/main.rs index 2e0f5b5255a1..b7c30364f294 100644 --- a/rust/tvm/tests/basics/src/main.rs +++ b/rust/tvm/tests/basics/src/main.rs @@ -35,7 +35,7 @@ fn main() { let mut arr = NDArray::empty(shape, dev, dtype); arr.copy_from_buffer(data.as_mut_slice()); let ret = NDArray::empty(shape, dev, dtype); - let mut fadd = Module::load(&concat!(env!("OUT_DIR"), "/test_add.so")).unwrap(); + let fadd = Module::load(&concat!(env!("OUT_DIR"), "/test_add.so")).unwrap(); if !fadd.enabled(dev_name) { return; }