Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ members = [
"tvm-graph-rt",
"tvm-graph-rt/tests/test_tvm_basic",
"tvm-graph-rt/tests/test_tvm_dso",
"tvm-graph-rt/tests/test_wasm32",
"tvm-graph-rt/tests/test_nn",
"compiler-ext",
]
2 changes: 1 addition & 1 deletion rust/tvm-graph-rt/tests/test_wasm32/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ authors = ["TVM Contributors"]
edition = "2018"

[dependencies]
ndarray="0.12"
ndarray = "0.12"
tvm-graph-rt = { path = "../../" }

[build-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion rust/tvm-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@ proc-macro = true
goblin = "^0.2"
proc-macro2 = "^1.0"
quote = "^1.0"
syn = { version = "1.0.17", features = ["full", "extra-traits"] }
syn = { version = "1.0.48", features = ["full", "parsing", "extra-traits"] }
proc-macro-error = "^1.0"
51 changes: 42 additions & 9 deletions rust/tvm-macros/src/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,35 @@
* under the License.
*/
use proc_macro2::Span;
use proc_macro_error::abort;
use quote::quote;
use syn::parse::{Parse, ParseStream, Result};

use syn::{FnArg, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type};
use syn::{
token::Semi, Attribute, FnArg, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType,
Signature, Type, Visibility,
};

struct ExternalItem {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: brief rustdoc? Some comments would help me get the story of what ExternalItem and External are for.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What Greg said

attrs: Vec<Attribute>,
visibility: Visibility,
sig: Signature,
}

impl Parse for ExternalItem {
fn parse(input: ParseStream) -> Result<Self> {
let item = ExternalItem {
attrs: input.call(Attribute::parse_outer)?,
visibility: input.parse()?,
sig: input.parse()?,
};
let _semi: Semi = input.parse()?;
Ok(item)
}
}

struct External {
visibility: Visibility,
tvm_name: String,
ident: Ident,
generics: Generics,
Expand All @@ -32,7 +55,8 @@ struct External {

impl Parse for External {
fn parse(input: ParseStream) -> Result<Self> {
let method: TraitItemMethod = input.parse()?;
let method: ExternalItem = input.parse()?;
let visibility = method.visibility;
assert_eq!(method.attrs.len(), 1);
let sig = method.sig;
let tvm_name = method.attrs[0].parse_meta()?;
Expand All @@ -47,8 +71,7 @@ impl Parse for External {
}
_ => panic!(),
};
assert_eq!(method.default, None);
assert!(method.semi_token != None);

let ident = sig.ident;
let generics = sig.generics;
let inputs = sig
Expand All @@ -60,6 +83,7 @@ impl Parse for External {
let ret_type = sig.output;

Ok(External {
visibility,
tvm_name,
ident,
generics,
Expand Down Expand Up @@ -98,6 +122,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let mut items = Vec::new();

for external in &ext_input.externs {
let visibility = &external.visibility;
let name = &external.ident;
let global_name = format!("global_{}", external.ident);
let global_name = Ident::new(&global_name, Span::call_site());
Expand All @@ -109,7 +134,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
.iter()
.map(|ty_param| match ty_param {
syn::GenericParam::Type(param) => param.clone(),
_ => panic!(),
_ => abort! { ty_param,
"Only supports type parameters."
},
})
.collect();

Expand All @@ -124,15 +151,21 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let ty: Type = *pat_type.ty.clone();
(ident, ty)
}
_ => panic!(),
_ => abort! { pat_type,
"Only supports type parameters."
},
},
pat => abort! {
pat, "invalid pattern type for function";

note = "{:?} is not allowed here", pat;
},
_ => panic!(),
})
.unzip();

let ret_type = match &external.ret_type {
ReturnType::Type(_, rtype) => *rtype.clone(),
_ => panic!(),
ReturnType::Default => syn::parse_str::<Type>("()").unwrap(),
};

let global = quote! {
Expand All @@ -147,7 +180,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
items.push(global);

let wrapper = quote! {
pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> {
#visibility fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> {
let func_ref: #tvm_rt_crate::Function = #global_name.clone();
let func_ref: Box<dyn Fn(#(#tys),*) -> #result_type<#ret_type>> = func_ref.into();
let res: #ret_type = func_ref(#(#args),*)?;
Expand Down
5 changes: 4 additions & 1 deletion rust/tvm-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/

use proc_macro::TokenStream;
use proc_macro_error::proc_macro_error;

mod external;
mod import_module;
Expand All @@ -29,12 +30,14 @@ pub fn import_module(input: TokenStream) -> TokenStream {
import_module::macro_impl(input)
}

#[proc_macro_derive(Object, attributes(base, ref_name, type_key))]
#[proc_macro_error]
#[proc_macro_derive(Object, attributes(base, ref_name, type_key, no_derive))]
pub fn macro_impl(input: TokenStream) -> TokenStream {
// let input = proc_macro2::TokenStream::from(input);
TokenStream::from(object::macro_impl(input))
}

#[proc_macro_error]
#[proc_macro]
pub fn external(input: TokenStream) -> TokenStream {
external::macro_impl(input)
Expand Down
32 changes: 31 additions & 1 deletion rust/tvm-macros/src/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
.map(attr_to_str)
.expect("Failed to get type_key");

let derive = get_attr(&derive_input, "no_derive")
.map(|_| false)
.unwrap_or(true);

let ref_id = get_attr(&derive_input, "ref_name")
.map(|a| Ident::new(attr_to_str(a).value().as_str(), Span::call_site()))
.unwrap_or_else(|| {
Expand Down Expand Up @@ -75,6 +79,12 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
_ => panic!("derive only works for structs"),
};

let ref_derives = if derive {
quote! { #[derive(Debug, Clone)]}
} else {
quote! { #[derive(Clone)] }
};

let mut expanded = quote! {
unsafe impl #tvm_rt_crate::object::IsObject for #payload_id {
const TYPE_KEY: &'static str = #type_key;
Expand All @@ -87,7 +97,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
}
}

#[derive(Clone)]
#ref_derives
pub struct #ref_id(Option<#tvm_rt_crate::object::ObjectPtr<#payload_id>>);

impl #tvm_rt_crate::object::IsObjectRef for #ref_id {
Expand Down Expand Up @@ -185,5 +195,25 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {

expanded.extend(base_tokens);

if derive {
let derives = quote! {
impl std::hash::Hash for #ref_id {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.0.hash(state)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly dumb comment (I'm not familiar with this macro) - is self.0 always the right thing to hash? Or is #ref_id sometimes a tuple with multiple fields, where field .1 and .2 etc. may be part of what distinguishes objects when hashing?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No we generate a wrapper new type over the ObjectPtr which contains 1 field always. ObjectPtr is currently dispatching to TVM's C++ hashing so we have cross language consistency.

}
}

impl std::cmp::PartialEq for #ref_id {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}

impl std::cmp::Eq for #ref_id {}
};

expanded.extend(derives);
}

TokenStream::from(expanded)
}
15 changes: 14 additions & 1 deletion rust/tvm-rt/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/

use std::convert::{TryFrom, TryInto};
use std::iter::{IntoIterator, Iterator};
use std::iter::{FromIterator, IntoIterator, Iterator};
use std::marker::PhantomData;

use crate::errors::Error;
Expand Down Expand Up @@ -82,6 +82,13 @@ impl<T: IsObjectRef> Array<T> {
}
}

impl<T: IsObjectRef> std::fmt::Debug for Array<T> {
fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
let as_vec: Vec<T> = self.clone().into_iter().collect();
write!(formatter, "{:?}", as_vec)
}
}

pub struct IntoIter<T: IsObjectRef> {
array: Array<T>,
pos: isize,
Expand Down Expand Up @@ -118,6 +125,12 @@ impl<T: IsObjectRef> IntoIterator for Array<T> {
}
}

impl<T: IsObjectRef> FromIterator<T> for Array<T> {
fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
Array::from_vec(iter.into_iter().collect()).unwrap()
}
}

impl<T: IsObjectRef> From<Array<T>> for ArgValue<'static> {
fn from(array: Array<T>) -> ArgValue<'static> {
array.object.into()
Expand Down
2 changes: 0 additions & 2 deletions rust/tvm-rt/src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ where
// TODO(@jroesch): convert to use generics instead of casting inside
// the implementation.
external! {
#[name("node.ArrayGetItem")]
fn array_get_item(array: ObjectRef, index: isize) -> ObjectRef;
#[name("node.MapSize")]
fn map_size(map: ObjectRef) -> i64;
#[name("node.MapGetItem")]
Expand Down
2 changes: 1 addition & 1 deletion rust/tvm-rt/src/ndarray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ use crate::object::{Object, ObjectPtr};

/// See the [`module-level documentation`](../ndarray/index.html) for more details.
#[repr(C)]
#[derive(Object)]
#[derive(Object, Debug)]
#[ref_name = "NDArray"]
#[type_key = "runtime.NDArray"]
pub struct NDArrayContainer {
Expand Down
12 changes: 4 additions & 8 deletions rust/tvm-rt/src/object/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pub trait IsObjectRef:
+ TryFrom<RetValue, Error = Error>
+ for<'a> Into<ArgValue<'a>>
+ for<'a> TryFrom<ArgValue<'a>, Error = Error>
+ std::fmt::Debug
{
type Object: IsObject;
fn as_ptr(&self) -> Option<&ObjectPtr<Self::Object>>;
Expand Down Expand Up @@ -88,14 +89,9 @@ pub trait IsObjectRef:

external! {
#[name("ir.DebugPrint")]
fn debug_print(object: ObjectRef) -> CString;
pub fn debug_print(object: ObjectRef) -> CString;
#[name("node.StructuralHash")]
fn structural_hash(object: ObjectRef, map_free_vars: bool) -> ObjectRef;
fn structural_hash(object: ObjectRef, map_free_vars: bool) -> i64;
#[name("node.StructuralEqual")]
fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool) -> ObjectRef;
fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool) -> bool;
}

// external! {
// #[name("ir.TextPrinter")]
// fn as_text(object: ObjectRef) -> CString;
// }
40 changes: 39 additions & 1 deletion rust/tvm-rt/src/object/object_ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

use std::convert::TryFrom;
use std::ffi::CString;
use std::fmt;
use std::ptr::NonNull;
use std::sync::atomic::AtomicI32;

Expand Down Expand Up @@ -147,14 +148,26 @@ impl Object {
}
}

// impl fmt::Debug for Object {
// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// let index =
// format!("{} // key: {}", self.type_index, "the_key");

// f.debug_struct("Object")
// .field("type_index", &index)
// // TODO(@jroesch: do we expose other fields?)
// .finish()
// }
// }

/// An unsafe trait which should be implemented for an object
/// subtype.
///
/// The trait contains the type key needed to compute the type
/// index, a method for accessing the base object given the
/// subtype, and a typed delete method which is specialized
/// to the subtype.
pub unsafe trait IsObject: AsRef<Object> {
pub unsafe trait IsObject: AsRef<Object> + std::fmt::Debug {
const TYPE_KEY: &'static str;

unsafe extern "C" fn typed_delete(object: *mut Self) {
Expand Down Expand Up @@ -264,6 +277,13 @@ impl<T: IsObject> std::ops::Deref for ObjectPtr<T> {
}
}

impl<T: IsObject> fmt::Debug for ObjectPtr<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use std::ops::Deref;
write!(f, "{:?}", self.deref())
}
}

impl<'a, T: IsObject> From<ObjectPtr<T>> for RetValue {
fn from(object_ptr: ObjectPtr<T>) -> RetValue {
let raw_object_ptr = ObjectPtr::leak(object_ptr) as *mut T as *mut std::ffi::c_void;
Expand Down Expand Up @@ -342,6 +362,24 @@ impl<'a, T: IsObject> TryFrom<ArgValue<'a>> for ObjectPtr<T> {
}
}

impl<T: IsObject> std::hash::Hash for ObjectPtr<T> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
state.write_i64(
super::structural_hash(ObjectRef(Some(self.clone().upcast())), false).unwrap(),
)
}
}

impl<T: IsObject> PartialEq for ObjectPtr<T> {
fn eq(&self, other: &Self) -> bool {
let lhs = ObjectRef(Some(self.clone().upcast()));
let rhs = ObjectRef(Some(other.clone().upcast()));
super::structural_equal(lhs, rhs, false, false).unwrap()
}
}

impl<T: IsObject> Eq for ObjectPtr<T> {}

#[cfg(test)]
mod tests {
use super::{Object, ObjectPtr};
Expand Down
3 changes: 2 additions & 1 deletion rust/tvm-rt/src/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ use super::Object;
use tvm_macros::Object;

#[repr(C)]
#[derive(Object)]
#[derive(Object, Debug)]
#[ref_name = "String"]
#[type_key = "runtime.String"]
#[no_derive]
pub struct StringObj {
base: Object,
data: *const u8,
Expand Down
1 change: 0 additions & 1 deletion rust/tvm-rt/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
//! `RetValue` is the owned version of `TVMPODValue`.

use std::convert::TryFrom;
// use std::ffi::c_void;

use crate::{ArgValue, Module, RetValue};
use tvm_sys::{errors::ValueDowncastError, ffi::TVMModuleHandle, try_downcast};
Expand Down
Loading