Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 39 additions & 50 deletions rust/tvm-rt/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
//! See the tests and examples repository for more examples.

use std::convert::{TryFrom, TryInto};
use std::sync::Arc;
use std::{
ffi::CString,
os::raw::{c_char, c_int},
Expand All @@ -39,36 +40,43 @@ pub use tvm_sys::{ffi, ArgValue, RetValue};

pub type Result<T> = std::result::Result<T, Error>;

/// Wrapper around TVM function handle which includes `is_global`
/// indicating whether the function is global or not, and `is_cloned` showing
/// not to drop a cloned function from Rust side.
/// The value of these fields can be accessed through their respective methods.
#[derive(Debug, Hash)]
pub struct Function {
pub(crate) handle: ffi::TVMFunctionHandle,
// whether the registered function is global or not.
is_global: bool,
from_rust: bool,
struct FunctionPtr {
handle: ffi::TVMFunctionHandle,
}

unsafe impl Send for Function {}
unsafe impl Sync for Function {}
// NB(@jroesch): I think this is ok, need to double check,
// if not we should mutex the pointer or move to Rc.
unsafe impl Send for FunctionPtr {}
unsafe impl Sync for FunctionPtr {}

impl FunctionPtr {
fn from_raw(handle: ffi::TVMFunctionHandle) -> Self {
FunctionPtr { handle }
}
}

impl Drop for FunctionPtr {
fn drop(&mut self) {
check_call!(ffi::TVMFuncFree(self.handle));
}
}

/// An owned thread-safe version of `tvm::PackedFunc` for consumption in Rust.
#[derive(Debug, Hash)]
pub struct Function {
inner: Arc<FunctionPtr>,
}

impl Function {
pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self {
pub(crate) fn from_raw(handle: ffi::TVMFunctionHandle) -> Self {
Function {
handle,
is_global: false,
from_rust: false,
inner: Arc::new(FunctionPtr::from_raw(handle)),
}
}

pub unsafe fn null() -> Self {
Function {
handle: std::ptr::null_mut(),
is_global: false,
from_rust: false,
}
Function::from_raw(std::ptr::null_mut())
}

/// For a given function, it returns a function by name.
Expand All @@ -84,11 +92,7 @@ impl Function {
if handle.is_null() {
None
} else {
Some(Function {
handle,
is_global: true,
from_rust: false,
})
Some(Function::from_raw(handle))
}
}

Expand All @@ -103,12 +107,7 @@ impl Function {

/// Returns the underlying TVM function handle.
pub fn handle(&self) -> ffi::TVMFunctionHandle {
self.handle
}

/// Returns `true` if the underlying TVM function is global and `false` otherwise.
pub fn is_global(&self) -> bool {
self.is_global
self.inner.handle
}

/// Calls the function that created from `Builder`.
Expand All @@ -122,7 +121,7 @@ impl Function {

let ret_code = unsafe {
ffi::TVMFuncCall(
self.handle,
self.handle(),
values.as_mut_ptr() as *mut ffi::TVMValue,
type_codes.as_mut_ptr() as *mut c_int,
num_args as c_int,
Expand Down Expand Up @@ -171,25 +170,15 @@ impl_to_fn!(T1, T2, T3, T4, T5, T6,);

impl Clone for Function {
fn clone(&self) -> Function {
Self {
handle: self.handle,
is_global: self.is_global,
from_rust: true,
Function {
inner: self.inner.clone(),
}
}
}

// impl Drop for Function {
// fn drop(&mut self) {
// if !self.is_global && !self.is_cloned {
// check_call!(ffi::TVMFuncFree(self.handle));
// }
// }
// }

impl From<Function> for RetValue {
fn from(func: Function) -> RetValue {
RetValue::FuncHandle(func.handle)
RetValue::FuncHandle(func.handle())
}
}

Expand All @@ -198,7 +187,7 @@ impl TryFrom<RetValue> for Function {

fn try_from(ret_value: RetValue) -> Result<Function> {
match ret_value {
RetValue::FuncHandle(handle) => Ok(Function::new(handle)),
RetValue::FuncHandle(handle) => Ok(Function::from_raw(handle)),
_ => Err(Error::downcast(
format!("{:?}", ret_value),
"FunctionHandle",
Expand All @@ -209,10 +198,10 @@ impl TryFrom<RetValue> for Function {

impl<'a> From<Function> for ArgValue<'a> {
fn from(func: Function) -> ArgValue<'a> {
if func.handle.is_null() {
if func.handle().is_null() {
ArgValue::Null
} else {
ArgValue::FuncHandle(func.handle)
ArgValue::FuncHandle(func.handle())
}
}
}
Expand All @@ -222,7 +211,7 @@ impl<'a> TryFrom<ArgValue<'a>> for Function {

fn try_from(arg_value: ArgValue<'a>) -> Result<Function> {
match arg_value {
ArgValue::FuncHandle(handle) => Ok(Function::new(handle)),
ArgValue::FuncHandle(handle) => Ok(Function::from_raw(handle)),
_ => Err(Error::downcast(
format!("{:?}", arg_value),
"FunctionHandle",
Expand All @@ -236,7 +225,7 @@ impl<'a> TryFrom<&ArgValue<'a>> for Function {

fn try_from(arg_value: &ArgValue<'a>) -> Result<Function> {
match arg_value {
ArgValue::FuncHandle(handle) => Ok(Function::new(*handle)),
ArgValue::FuncHandle(handle) => Ok(Function::from_raw(*handle)),
_ => Err(Error::downcast(
format!("{:?}", arg_value),
"FunctionHandle",
Expand Down
2 changes: 1 addition & 1 deletion rust/tvm-rt/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl Module {
return Err(errors::Error::NullHandle(name.into_string()?.to_string()));
}

Ok(Function::new(fhandle))
Ok(Function::from_raw(fhandle))
}

/// Imports a dependent module such as `.ptx` for cuda gpu.
Expand Down
4 changes: 2 additions & 2 deletions rust/tvm-rt/src/ndarray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ use num_traits::Num;

use crate::errors::NDArrayError;

use crate::object::{Object, ObjectPtr};
use crate::object::{Object, ObjectPtr, ObjectRef};

/// See the [`module-level documentation`](../ndarray/index.html) for more details.
#[repr(C)]
Expand All @@ -73,7 +73,7 @@ pub struct NDArrayContainer {
// Container Base
dl_tensor: DLTensor,
manager_ctx: *mut c_void,
// TOOD: shape?
shape: ObjectRef,
}

impl NDArrayContainer {
Expand Down
12 changes: 0 additions & 12 deletions rust/tvm-rt/src/object/object_ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,18 +148,6 @@ impl Object {
}
}

// impl fmt::Debug for Object {
// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// let index =
// format!("{} // key: {}", self.type_index, "the_key");

// f.debug_struct("Object")
// .field("type_index", &index)
// // TODO(@jroesch: do we expose other fields?)
// .finish()
// }
// }

/// An unsafe trait which should be implemented for an object
/// subtype.
///
Expand Down
2 changes: 1 addition & 1 deletion rust/tvm-rt/src/to_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ pub trait ToFunction<I, O>: Sized {
&mut fhandle as *mut ffi::TVMFunctionHandle,
));

Function::new(fhandle)
Function::from_raw(fhandle)
}

/// The callback function which is wrapped converted by TVM
Expand Down
10 changes: 8 additions & 2 deletions rust/tvm-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@

extern crate bindgen;

use std::{path::{Path, PathBuf}, str::FromStr};
use std::{
path::{Path, PathBuf},
str::FromStr,
};

use anyhow::{Context, Result};
use tvm_build::{BuildConfig, CMakeSetting};
Expand Down Expand Up @@ -195,7 +198,10 @@ fn find_using_tvm_build() -> Result<TVMInstall> {
if cfg!(feature = "use-vitis-ai") {
build_config.settings.use_vitis_ai = Some(true);
}
if cfg!(any(feature = "static-linking", feature = "build-static-runtime")) {
if cfg!(any(
feature = "static-linking",
feature = "build-static-runtime"
)) {
build_config.settings.build_static_runtime = Some(true);
}

Expand Down
2 changes: 1 addition & 1 deletion rust/tvm/tests/basics/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn main() {
let mut arr = NDArray::empty(shape, dev, dtype);
arr.copy_from_buffer(data.as_mut_slice());
let ret = NDArray::empty(shape, dev, dtype);
let mut fadd = Module::load(&concat!(env!("OUT_DIR"), "/test_add.so")).unwrap();
let fadd = Module::load(&concat!(env!("OUT_DIR"), "/test_add.so")).unwrap();
if !fadd.enabled(dev_name) {
return;
}
Expand Down