diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 7c092d860b50..e75150859f90 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -27,7 +27,6 @@ members = [ "tvm-graph-rt", "tvm-graph-rt/tests/test_tvm_basic", "tvm-graph-rt/tests/test_tvm_dso", - "tvm-graph-rt/tests/test_wasm32", "tvm-graph-rt/tests/test_nn", "compiler-ext", ] diff --git a/rust/tvm-graph-rt/tests/test_wasm32/Cargo.toml b/rust/tvm-graph-rt/tests/test_wasm32/Cargo.toml index aed467f1235d..02e77d106f28 100644 --- a/rust/tvm-graph-rt/tests/test_wasm32/Cargo.toml +++ b/rust/tvm-graph-rt/tests/test_wasm32/Cargo.toml @@ -23,7 +23,7 @@ authors = ["TVM Contributors"] edition = "2018" [dependencies] -ndarray="0.12" +ndarray = "0.12" tvm-graph-rt = { path = "../../" } [build-dependencies] diff --git a/rust/tvm-macros/Cargo.toml b/rust/tvm-macros/Cargo.toml index 63b84727c525..e491177d8599 100644 --- a/rust/tvm-macros/Cargo.toml +++ b/rust/tvm-macros/Cargo.toml @@ -33,5 +33,5 @@ proc-macro = true goblin = "^0.2" proc-macro2 = "^1.0" quote = "^1.0" -syn = { version = "1.0.17", features = ["full", "extra-traits"] } +syn = { version = "1.0.48", features = ["full", "parsing", "extra-traits"] } proc-macro-error = "^1.0" diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs index 802d7aeb6779..146f9d4d6bc6 100644 --- a/rust/tvm-macros/src/external.rs +++ b/rust/tvm-macros/src/external.rs @@ -17,12 +17,35 @@ * under the License. */ use proc_macro2::Span; +use proc_macro_error::abort; use quote::quote; use syn::parse::{Parse, ParseStream, Result}; -use syn::{FnArg, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type}; +use syn::{ + token::Semi, Attribute, FnArg, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, + Signature, Type, Visibility, +}; + +struct ExternalItem { + attrs: Vec, + visibility: Visibility, + sig: Signature, +} + +impl Parse for ExternalItem { + fn parse(input: ParseStream) -> Result { + let item = ExternalItem { + attrs: input.call(Attribute::parse_outer)?, + visibility: input.parse()?, + sig: input.parse()?, + }; + let _semi: Semi = input.parse()?; + Ok(item) + } +} struct External { + visibility: Visibility, tvm_name: String, ident: Ident, generics: Generics, @@ -32,7 +55,8 @@ struct External { impl Parse for External { fn parse(input: ParseStream) -> Result { - let method: TraitItemMethod = input.parse()?; + let method: ExternalItem = input.parse()?; + let visibility = method.visibility; assert_eq!(method.attrs.len(), 1); let sig = method.sig; let tvm_name = method.attrs[0].parse_meta()?; @@ -47,8 +71,7 @@ impl Parse for External { } _ => panic!(), }; - assert_eq!(method.default, None); - assert!(method.semi_token != None); + let ident = sig.ident; let generics = sig.generics; let inputs = sig @@ -60,6 +83,7 @@ impl Parse for External { let ret_type = sig.output; Ok(External { + visibility, tvm_name, ident, generics, @@ -98,6 +122,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let mut items = Vec::new(); for external in &ext_input.externs { + let visibility = &external.visibility; let name = &external.ident; let global_name = format!("global_{}", external.ident); let global_name = Ident::new(&global_name, Span::call_site()); @@ -109,7 +134,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { .iter() .map(|ty_param| match ty_param { syn::GenericParam::Type(param) => param.clone(), - _ => panic!(), + _ => abort! { ty_param, + "Only supports type parameters." + }, }) .collect(); @@ -124,15 +151,21 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let ty: Type = *pat_type.ty.clone(); (ident, ty) } - _ => panic!(), + _ => abort! { pat_type, + "Only supports type parameters." + }, + }, + pat => abort! { + pat, "invalid pattern type for function"; + + note = "{:?} is not allowed here", pat; }, - _ => panic!(), }) .unzip(); let ret_type = match &external.ret_type { ReturnType::Type(_, rtype) => *rtype.clone(), - _ => panic!(), + ReturnType::Default => syn::parse_str::("()").unwrap(), }; let global = quote! { @@ -147,7 +180,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { items.push(global); let wrapper = quote! { - pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> { + #visibility fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> { let func_ref: #tvm_rt_crate::Function = #global_name.clone(); let func_ref: Box #result_type<#ret_type>> = func_ref.into(); let res: #ret_type = func_ref(#(#args),*)?; diff --git a/rust/tvm-macros/src/lib.rs b/rust/tvm-macros/src/lib.rs index 603e1ceaafcc..e563a57f149e 100644 --- a/rust/tvm-macros/src/lib.rs +++ b/rust/tvm-macros/src/lib.rs @@ -18,6 +18,7 @@ */ use proc_macro::TokenStream; +use proc_macro_error::proc_macro_error; mod external; mod import_module; @@ -29,12 +30,14 @@ pub fn import_module(input: TokenStream) -> TokenStream { import_module::macro_impl(input) } -#[proc_macro_derive(Object, attributes(base, ref_name, type_key))] +#[proc_macro_error] +#[proc_macro_derive(Object, attributes(base, ref_name, type_key, no_derive))] pub fn macro_impl(input: TokenStream) -> TokenStream { // let input = proc_macro2::TokenStream::from(input); TokenStream::from(object::macro_impl(input)) } +#[proc_macro_error] #[proc_macro] pub fn external(input: TokenStream) -> TokenStream { external::macro_impl(input) diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs index ff72d6a649be..c84d0aab612f 100644 --- a/rust/tvm-macros/src/object.rs +++ b/rust/tvm-macros/src/object.rs @@ -36,6 +36,10 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { .map(attr_to_str) .expect("Failed to get type_key"); + let derive = get_attr(&derive_input, "no_derive") + .map(|_| false) + .unwrap_or(true); + let ref_id = get_attr(&derive_input, "ref_name") .map(|a| Ident::new(attr_to_str(a).value().as_str(), Span::call_site())) .unwrap_or_else(|| { @@ -75,6 +79,12 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { _ => panic!("derive only works for structs"), }; + let ref_derives = if derive { + quote! { #[derive(Debug, Clone)]} + } else { + quote! { #[derive(Clone)] } + }; + let mut expanded = quote! { unsafe impl #tvm_rt_crate::object::IsObject for #payload_id { const TYPE_KEY: &'static str = #type_key; @@ -87,7 +97,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { } } - #[derive(Clone)] + #ref_derives pub struct #ref_id(Option<#tvm_rt_crate::object::ObjectPtr<#payload_id>>); impl #tvm_rt_crate::object::IsObjectRef for #ref_id { @@ -185,5 +195,25 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { expanded.extend(base_tokens); + if derive { + let derives = quote! { + impl std::hash::Hash for #ref_id { + fn hash(&self, state: &mut H) { + self.0.hash(state) + } + } + + impl std::cmp::PartialEq for #ref_id { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } + } + + impl std::cmp::Eq for #ref_id {} + }; + + expanded.extend(derives); + } + TokenStream::from(expanded) } diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index 98414f9c5b34..1b0ce8399d1f 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -18,7 +18,7 @@ */ use std::convert::{TryFrom, TryInto}; -use std::iter::{IntoIterator, Iterator}; +use std::iter::{FromIterator, IntoIterator, Iterator}; use std::marker::PhantomData; use crate::errors::Error; @@ -82,6 +82,13 @@ impl Array { } } +impl std::fmt::Debug for Array { + fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + let as_vec: Vec = self.clone().into_iter().collect(); + write!(formatter, "{:?}", as_vec) + } +} + pub struct IntoIter { array: Array, pos: isize, @@ -118,6 +125,12 @@ impl IntoIterator for Array { } } +impl FromIterator for Array { + fn from_iter>(iter: I) -> Self { + Array::from_vec(iter.into_iter().collect()).unwrap() + } +} + impl From> for ArgValue<'static> { fn from(array: Array) -> ArgValue<'static> { array.object.into() diff --git a/rust/tvm-rt/src/map.rs b/rust/tvm-rt/src/map.rs index 721fb1ec4588..b8bfb4e5e644 100644 --- a/rust/tvm-rt/src/map.rs +++ b/rust/tvm-rt/src/map.rs @@ -48,8 +48,6 @@ where // TODO(@jroesch): convert to use generics instead of casting inside // the implementation. external! { - #[name("node.ArrayGetItem")] - fn array_get_item(array: ObjectRef, index: isize) -> ObjectRef; #[name("node.MapSize")] fn map_size(map: ObjectRef) -> i64; #[name("node.MapGetItem")] diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index ed280ccc2d80..07f783f0ef43 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -65,7 +65,7 @@ use crate::object::{Object, ObjectPtr}; /// See the [`module-level documentation`](../ndarray/index.html) for more details. #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "NDArray"] #[type_key = "runtime.NDArray"] pub struct NDArrayContainer { diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs index 46e034232a63..8c07ed9f0853 100644 --- a/rust/tvm-rt/src/object/mod.rs +++ b/rust/tvm-rt/src/object/mod.rs @@ -40,6 +40,7 @@ pub trait IsObjectRef: + TryFrom + for<'a> Into> + for<'a> TryFrom, Error = Error> + + std::fmt::Debug { type Object: IsObject; fn as_ptr(&self) -> Option<&ObjectPtr>; @@ -88,14 +89,9 @@ pub trait IsObjectRef: external! { #[name("ir.DebugPrint")] - fn debug_print(object: ObjectRef) -> CString; + pub fn debug_print(object: ObjectRef) -> CString; #[name("node.StructuralHash")] - fn structural_hash(object: ObjectRef, map_free_vars: bool) -> ObjectRef; + fn structural_hash(object: ObjectRef, map_free_vars: bool) -> i64; #[name("node.StructuralEqual")] - fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool) -> ObjectRef; + fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool) -> bool; } - -// external! { -// #[name("ir.TextPrinter")] -// fn as_text(object: ObjectRef) -> CString; -// } diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 8d535368c352..8df6041956b8 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -19,6 +19,7 @@ use std::convert::TryFrom; use std::ffi::CString; +use std::fmt; use std::ptr::NonNull; use std::sync::atomic::AtomicI32; @@ -147,6 +148,18 @@ impl Object { } } +// impl fmt::Debug for Object { +// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +// let index = +// format!("{} // key: {}", self.type_index, "the_key"); + +// f.debug_struct("Object") +// .field("type_index", &index) +// // TODO(@jroesch: do we expose other fields?) +// .finish() +// } +// } + /// An unsafe trait which should be implemented for an object /// subtype. /// @@ -154,7 +167,7 @@ impl Object { /// index, a method for accessing the base object given the /// subtype, and a typed delete method which is specialized /// to the subtype. -pub unsafe trait IsObject: AsRef { +pub unsafe trait IsObject: AsRef + std::fmt::Debug { const TYPE_KEY: &'static str; unsafe extern "C" fn typed_delete(object: *mut Self) { @@ -264,6 +277,13 @@ impl std::ops::Deref for ObjectPtr { } } +impl fmt::Debug for ObjectPtr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use std::ops::Deref; + write!(f, "{:?}", self.deref()) + } +} + impl<'a, T: IsObject> From> for RetValue { fn from(object_ptr: ObjectPtr) -> RetValue { let raw_object_ptr = ObjectPtr::leak(object_ptr) as *mut T as *mut std::ffi::c_void; @@ -342,6 +362,24 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { } } +impl std::hash::Hash for ObjectPtr { + fn hash(&self, state: &mut H) { + state.write_i64( + super::structural_hash(ObjectRef(Some(self.clone().upcast())), false).unwrap(), + ) + } +} + +impl PartialEq for ObjectPtr { + fn eq(&self, other: &Self) -> bool { + let lhs = ObjectRef(Some(self.clone().upcast())); + let rhs = ObjectRef(Some(other.clone().upcast())); + super::structural_equal(lhs, rhs, false, false).unwrap() + } +} + +impl Eq for ObjectPtr {} + #[cfg(test)] mod tests { use super::{Object, ObjectPtr}; diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs index 3cd33a226d44..e61afaf7399b 100644 --- a/rust/tvm-rt/src/string.rs +++ b/rust/tvm-rt/src/string.rs @@ -25,9 +25,10 @@ use super::Object; use tvm_macros::Object; #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "String"] #[type_key = "runtime.String"] +#[no_derive] pub struct StringObj { base: Object, data: *const u8, diff --git a/rust/tvm-rt/src/value.rs b/rust/tvm-rt/src/value.rs index c49944dc7e33..b8cd190176c4 100644 --- a/rust/tvm-rt/src/value.rs +++ b/rust/tvm-rt/src/value.rs @@ -22,7 +22,6 @@ //! `RetValue` is the owned version of `TVMPODValue`. use std::convert::TryFrom; -// use std::ffi::c_void; use crate::{ArgValue, Module, RetValue}; use tvm_sys::{errors::ValueDowncastError, ffi::TVMModuleHandle, try_downcast}; diff --git a/rust/tvm-sys/src/datatype.rs b/rust/tvm-sys/src/datatype.rs index 8050d932e5c1..5f7e0c3a3b60 100644 --- a/rust/tvm-sys/src/datatype.rs +++ b/rust/tvm-sys/src/datatype.rs @@ -83,6 +83,10 @@ impl DataType { DataType::new(DL_FLOAT_CODE, bits, lanes) } + pub const fn float32() -> DataType { + Self::float(32, 1) + } + pub const fn uint(bits: u8, lanes: u16) -> DataType { DataType::new(DL_UINT_CODE, bits, lanes) } diff --git a/rust/tvm/src/ir/arith.rs b/rust/tvm/src/ir/arith.rs index 92a1de69ff78..672e6e6113a0 100644 --- a/rust/tvm/src/ir/arith.rs +++ b/rust/tvm/src/ir/arith.rs @@ -24,7 +24,7 @@ use tvm_macros::Object; macro_rules! define_node { ($name:ident, $ref:expr, $typekey:expr; $node:ident { $($id:ident : $t:ty),*}) => { #[repr(C)] - #[derive(Object)] + #[derive(Object, Debug)] #[ref_name = $ref] #[type_key = $typekey] pub struct $node { diff --git a/rust/tvm/src/ir/attrs.rs b/rust/tvm/src/ir/attrs.rs index 5bd027ab4b4c..739ed405c906 100644 --- a/rust/tvm/src/ir/attrs.rs +++ b/rust/tvm/src/ir/attrs.rs @@ -21,7 +21,7 @@ use crate::runtime::Object; use tvm_macros::Object; #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Attrs"] #[type_key = "Attrs"] pub struct BaseAttrsNode { diff --git a/rust/tvm/src/ir/diagnostics/mod.rs b/rust/tvm/src/ir/diagnostics/mod.rs index 051bb9eb16c4..8bcdf8f51e60 100644 --- a/rust/tvm/src/ir/diagnostics/mod.rs +++ b/rust/tvm/src/ir/diagnostics/mod.rs @@ -59,6 +59,7 @@ external! { /// The diagnostic level, controls the printing of the message. #[repr(C)] +#[derive(PartialEq, Eq, Debug)] pub enum DiagnosticLevel { Bug = 10, Error = 20, @@ -69,7 +70,7 @@ pub enum DiagnosticLevel { /// A compiler diagnostic. #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Diagnostic"] #[type_key = "Diagnostic"] pub struct DiagnosticNode { @@ -145,7 +146,7 @@ impl DiagnosticBuilder { /// of compiler diagnostics to std::out and std::err in /// a human readable form. #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "DiagnosticRenderer"] #[type_key = "DiagnosticRenderer"] /// A diagnostic renderer, which given a diagnostic context produces a "rendered" @@ -166,7 +167,7 @@ impl DiagnosticRenderer { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "DiagnosticContext"] #[type_key = "DiagnosticContext"] /// A diagnostic context for recording errors against a source file. diff --git a/rust/tvm/src/ir/expr.rs b/rust/tvm/src/ir/expr.rs index f74522d91c70..653169def3a4 100644 --- a/rust/tvm/src/ir/expr.rs +++ b/rust/tvm/src/ir/expr.rs @@ -17,15 +17,17 @@ * under the License. */ -use super::relay; +use tvm_macros::Object; + use crate::runtime::String as TString; use crate::runtime::{self, external, IsObject, IsObjectRef, Object, ObjectPtr, ObjectRef}; use crate::DataType; -use tvm_macros::Object; +use super::relay; +use super::span::Span; #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "BaseExpr"] #[type_key = "Expr"] pub struct BaseExprNode { @@ -41,7 +43,7 @@ impl BaseExprNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "PrimExpr"] #[type_key = "PrimExpr"] pub struct PrimExprNode { @@ -59,7 +61,7 @@ impl PrimExprNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "GlobalVar"] #[type_key = "GlobalVar"] pub struct GlobalVarNode { @@ -68,7 +70,7 @@ pub struct GlobalVarNode { } impl GlobalVar { - pub fn new(name_hint: String, _span: ObjectRef) -> GlobalVar { + pub fn new(name_hint: String, _span: Span) -> GlobalVar { let node = GlobalVarNode { base: relay::ExprNode::base::(), name_hint: name_hint.into(), diff --git a/rust/tvm/src/ir/function.rs b/rust/tvm/src/ir/function.rs index 3043bf9e7cff..14c00ea02bf6 100644 --- a/rust/tvm/src/ir/function.rs +++ b/rust/tvm/src/ir/function.rs @@ -28,7 +28,7 @@ use tvm_macros::Object; pub type DictAttrs = ObjectRef; #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "BaseFunc"] #[type_key = "BaseFunc"] pub struct BaseFuncNode { diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index 190b477b98f2..a09f70dc25b9 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -16,6 +16,9 @@ * specific language governing permissions and limitations * under the License. */ + +use std::collections::HashMap; +use std::iter::FromIterator; use std::path::Path; use thiserror::Error; @@ -25,15 +28,12 @@ use crate::runtime::array::Array; use crate::runtime::function::Result; use crate::runtime::map::Map; use crate::runtime::string::String as TVMString; -use crate::runtime::{external, Object, ObjectRef}; +use crate::runtime::{external, IsObjectRef, Object}; use super::expr::GlobalVar; use super::function::BaseFunc; use super::source_map::SourceMap; - -// TODO(@jroesch): define type -type TypeData = ObjectRef; -type GlobalTypeVar = ObjectRef; +use super::{relay, ty::GlobalTypeVar, ty::TypeData}; #[derive(Error, Debug)] pub enum Error { @@ -44,7 +44,7 @@ pub enum Error { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "IRModule"] #[type_key = "IRModule"] pub struct IRModuleNode { @@ -61,7 +61,11 @@ external! { fn parse_module(file_name: TVMString, source: TVMString) -> IRModule; #[name("parser.ParseExpr")] fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule; + #[name("ir.IRModule")] + fn module_new(funcs: Map, types: Map) -> IRModule; // Module methods + #[name("ir.Module_Add")] + fn module_add(module: IRModule, type_name: GlobalVar, expr: BaseFunc, update: bool) -> IRModule; #[name("ir.Module_AddDef")] fn module_add_def(module: IRModule, type_name: GlobalTypeVar, type_data: TypeData, update: bool) -> (); #[name("ir.Module_GetGlobalVar")] @@ -72,57 +76,43 @@ external! { fn module_lookup(module: IRModule, var: GlobalVar) -> BaseFunc; #[name("ir.Module_Lookup_str")] fn module_lookup_str(module: IRModule, name: TVMString) -> BaseFunc; + #[name("ir.Module_GetGlobalTypeVars")] + fn module_get_global_type_vars(module: IRModule) -> Array; + #[name("ir.Module_ContainGlobalVar")] + fn module_contains_global_var(module: IRModule, name: TVMString) -> bool; + #[name("ir.Module_ContainGlobalTypeVar")] + fn module_contains_global_type_var(module: IRModule, name: TVMString) -> bool; + #[name("ir.Module_LookupDef")] + fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeData; + #[name("ir.Module_LookupDef_str")] + fn module_lookup_def_str(module: IRModule, global: TVMString) -> TypeData; + #[name("ir.Module_LookupTag")] + fn module_lookup_tag(module: IRModule, tag: i32) -> relay::Constructor; + #[name("ir.Module_FromExpr")] + fn module_from_expr(expr: relay::Expr, funcs: Map, types: Map) -> IRModule; + #[name("ir.Module_Import")] + fn module_import(module: IRModule, path: TVMString); + #[name("ir.Module_ImportFromStd")] + fn module_import_from_std(module: IRModule, path: TVMString); } -// TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars") -// .set_body_method(&IRModuleNode::GetGlobalTypeVars); - -// TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar") -// .set_body_method(&IRModuleNode::ContainGlobalVar); - -// TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar") -// .set_body_method(&IRModuleNode::GetGlobalTypeVar); - -// TVM_REGISTER_GLOBAL("ir.Module_LookupDef").set_body_typed([](IRModule mod, GlobalTypeVar var) { -// return mod->LookupTypeDef(var); -// }); - -// TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str").set_body_typed([](IRModule mod, String var) { -// return mod->LookupTypeDef(var); -// }); +// Note: we don't expose update here as update is going to be removed. -// TVM_REGISTER_GLOBAL("ir.Module_LookupTag").set_body_typed([](IRModule mod, int32_t tag) { -// return mod->LookupTag(tag); -// }); - -// TVM_REGISTER_GLOBAL("ir.Module_FromExpr") -// .set_body_typed([](RelayExpr e, tvm::Map funcs, -// tvm::Map type_defs) { -// return IRModule::FromExpr(e, funcs, type_defs); -// }); - -// TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule from) { -// mod->Update(from); -// }); - -// TVM_REGISTER_GLOBAL("ir.Module_UpdateFunction") -// .set_body_typed([](IRModule mod, GlobalVar gv, BaseFunc func) { mod->Update(gv, func); }); - -// TVM_REGISTER_GLOBAL("ir.Module_Import").set_body_typed([](IRModule mod, String path) { -// mod->Import(path); -// }); - -// TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd").set_body_typed([](IRModule mod, String path) { -// mod->ImportFromStd(path); -// }); +impl IRModule { + pub fn new(funcs: F, types: T) -> Result + where + F: IntoIterator, + T: IntoIterator, + { + module_new(Map::from_iter(funcs), Map::from_iter(types)) + } -// TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -// .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { -// auto* node = static_cast(ref.get()); -// p->stream << "IRModuleNode( " << node->functions << ")"; -// }); + pub fn empty() -> Result { + let funcs = HashMap::::new(); + let types = HashMap::::new(); + IRModule::new(funcs, types) + } -impl IRModule { pub fn parse(file_name: N, source: S) -> Result where N: Into, @@ -141,6 +131,15 @@ impl IRModule { Ok(module) } + pub fn add(&mut self, var: GlobalVar, func: F) -> Result + // todo(@jroesch): can we do better here? why doesn't BaseFunc::Object work? + where + F: IsObjectRef, + F::Object: AsRef<::Object>, + { + module_add(self.clone(), var, func.upcast(), true) + } + pub fn add_def( &mut self, type_name: GlobalTypeVar, @@ -150,8 +149,11 @@ impl IRModule { module_add_def(self.clone(), type_name, type_data, update) } - pub fn get_global_var(&self, name: TVMString) -> Result { - module_get_global_var(self.clone(), name) + pub fn get_global_var(&self, name: S) -> Result + where + S: Into, + { + module_get_global_var(self.clone(), name.into()) } pub fn get_global_vars(&self) -> Result> { @@ -168,4 +170,216 @@ impl IRModule { { module_lookup_str(self.clone(), name.into()) } + + pub fn get_global_type_vars(&self) -> Result> { + module_get_global_type_vars(self.clone()) + } + + pub fn contains_global_var>(&self, name: S) -> Result { + module_contains_global_var(self.clone(), name.into()) + } + + pub fn contains_global_type_var>(&self, name: S) -> Result { + module_contains_global_type_var(self.clone(), name.into()) + } + + pub fn lookup_def(&self, global: GlobalTypeVar) -> Result { + module_lookup_def(self.clone(), global) + } + + pub fn lookup_def_str(&self, global: S) -> Result + where + S: Into, + { + module_lookup_def_str(self.clone(), global.into()) + } + + pub fn lookup_tag(&self, tag: i32) -> Result { + module_lookup_tag(self.clone(), tag) + } + + pub fn from_expr(expr: E) -> Result + where + E: IsObjectRef, + E::Object: AsRef<::Object>, + { + Self::from_expr_with_items(expr, HashMap::new(), HashMap::new()) + } + + pub fn from_expr_with_items(expr: E, funcs: F, types: T) -> Result + where + F: IntoIterator, + T: IntoIterator, + E: IsObjectRef, + E::Object: AsRef<::Object>, + { + module_from_expr(expr.upcast(), Map::from_iter(funcs), Map::from_iter(types)) + } + + pub fn import>(&mut self, path: S) -> Result<()> { + module_import(self.clone(), path.into()) + } + + pub fn import_from_std>(&mut self, path: S) -> Result<()> { + module_import_from_std(self.clone(), path.into()) + } +} + +#[cfg(test)] +mod tests { + use super::relay::*; + use super::*; + use crate::ir::span::Span; + use crate::ir::ty::{GlobalTypeVar, TypeData, TypeKind}; + use tvm_rt::IsObjectRef; + + fn add_dummy_functions(names: Vec<&str>) -> Result { + let mut module = IRModule::empty()?; + let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32()); + let params = vec![x.clone()]; + let func = relay::Function::simple(params, x); + + for name in names { + let gv = GlobalVar::new(name.into(), Span::null()); + module = module.add(gv, func.clone())?; + } + + Ok(module) + } + + fn add_dummy_types(names: Vec<&str>) -> Result { + let mut module = IRModule::empty()?; + + for name in names { + let name: String = name.into(); + let name = GlobalTypeVar::new(name, TypeKind::Type, Span::null()); + let type_data = TypeData::new(name.clone(), vec![], vec![], Span::null()); + module.add_def(name, type_data, true)?; + } + + Ok(module) + } + + #[test] + fn test_module_add() -> anyhow::Result<()> { + let mut module = IRModule::empty()?; + let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32()); + let params = vec![x.clone()]; + let func = relay::Function::simple(params, x); + let module = module.add(GlobalVar::new("foo".into(), Span::null()), func)?; + let lfunc = module.lookup_str("foo")?; + let lfunc = lfunc.downcast::()?; + assert_eq!(lfunc.params.len(), 1); + Ok(()) + } + + #[test] + fn test_module_add_def() -> Result<()> { + let mut module = IRModule::empty()?; + let name = GlobalTypeVar::new("my_type", TypeKind::Type, Span::null()); + let type_data = TypeData::new(name.clone(), vec![], vec![], Span::null()); + module.add_def(name.clone(), type_data, true)?; + let by_gtv = module.lookup_def(name)?; + let by_gv = module.lookup_def_str("my_type")?; + Ok(()) + } + + #[test] + fn test_get_global_var() -> Result<()> { + let mut module = IRModule::empty()?; + let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32()); + let params = vec![x.clone()]; + let func = relay::Function::simple(params, x); + let gv_foo = GlobalVar::new("foo".into(), Span::null()); + let module = module.add(gv_foo.clone(), func)?; + let gv = module.get_global_var("foo")?; + assert_eq!(gv_foo, gv); + Ok(()) + } + + #[test] + fn test_get_global_vars() -> Result<()> { + let names = vec!["foo", "bar", "baz"]; + let module = add_dummy_functions(names.clone())?; + let gvars: Vec = module + .get_global_vars()? + .into_iter() + .map(|gv| gv.name_hint.as_str().unwrap().to_string()) + .collect(); + + for name in names { + assert!(gvars.contains(&name.to_string())); + } + + Ok(()) + } + + #[test] + fn test_get_global_type_vars() -> Result<()> { + let names = vec!["foo", "bar", "baz"]; + let module = add_dummy_types(names.clone())?; + let gvars: Vec = module + .get_global_type_vars()? + .into_iter() + .map(|gv| gv.name_hint.as_str().unwrap().to_string()) + .collect(); + + for name in names { + assert!(gvars.contains(&name.to_string())); + } + + Ok(()) + } + + #[test] + fn test_contains_global_var() -> Result<()> { + let module = add_dummy_functions(vec!["foo"])?; + assert!(module.contains_global_var("foo")?); + Ok(()) + } + + #[test] + fn test_contains_global_type_var() -> Result<()> { + let module = add_dummy_types(vec!["foo"])?; + assert!(module.contains_global_type_var("foo")?); + Ok(()) + } + + // TODO(@jroesch): not really sure about this API at all. + // pub fn lookup_tag(&self, tag: i32) -> Result { + // module_lookup_tag(self.clone(), tag) + // } + + #[test] + fn test_from_expr() -> Result<()> { + let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32()); + let params = vec![x.clone()]; + let func = relay::Function::simple(params, x); + let module = IRModule::from_expr(func.clone())?; + let main_fn = module.lookup_str("main")?; + let main_fn = main_fn.downcast::()?; + assert_eq!(main_fn, func); + Ok(()) + } + + #[test] + fn test_import() -> Result<()> { + let mut std_path: String = env!("CARGO_MANIFEST_DIR").into(); + std_path += "/../../python/tvm/relay/std/prelude.rly"; + + let mut mod1 = IRModule::empty()?; + mod1.import(std_path.clone())?; + mod1.lookup_str("map")?; + + // TODO(@jroesch): this requires another patch of mine to enable. + + // if cfg!(feature = "python") { + // crate::python::load().unwrap(); + // let mut mod2 = IRModule::empty()?; + // mod2.import_from_std("prelude.rly")?; + // mod2.lookup_str("map")?; + // } + + Ok(()) + } } diff --git a/rust/tvm/src/ir/op.rs b/rust/tvm/src/ir/op.rs index d81d6a69c1eb..d222ead0391b 100644 --- a/rust/tvm/src/ir/op.rs +++ b/rust/tvm/src/ir/op.rs @@ -27,7 +27,7 @@ type FuncType = ObjectRef; type AttrFieldInfo = ObjectRef; #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Op"] #[type_key = "Op"] pub struct OpNode { diff --git a/rust/tvm/src/ir/relay/attrs/nn.rs b/rust/tvm/src/ir/relay/attrs/nn.rs index cb96f0fbf588..7ecd92febc22 100644 --- a/rust/tvm/src/ir/relay/attrs/nn.rs +++ b/rust/tvm/src/ir/relay/attrs/nn.rs @@ -27,7 +27,7 @@ use tvm_macros::Object; type IndexExpr = PrimExpr; #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Conv2DAttrs"] #[type_key = "relay.attrs.Conv2DAttrs"] pub struct Conv2DAttrsNode { @@ -46,7 +46,7 @@ pub struct Conv2DAttrsNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "BiasAddAttrs"] #[type_key = "relay.attrs.BiasAddAttrs"] pub struct BiasAddAttrsNode { @@ -55,7 +55,7 @@ pub struct BiasAddAttrsNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "DenseAttrs"] #[type_key = "relay.attrs.DenseAttrs"] pub struct DenseAttrsNode { @@ -65,7 +65,7 @@ pub struct DenseAttrsNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "GlobalPool2DAttrs"] #[type_key = "relay.attrs.GlobalPool2DAttrs"] pub struct GlobalPool2DAttrsNode { @@ -74,7 +74,7 @@ pub struct GlobalPool2DAttrsNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "MaxPool2DAttrs"] #[type_key = "relay.attrs.MaxPool2DAttrs"] pub struct MaxPool2DAttrsNode { @@ -87,7 +87,7 @@ pub struct MaxPool2DAttrsNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "SoftmaxAttrs"] #[type_key = "relay.attrs.SoftmaxAttrs"] pub struct SoftmaxAttrsNode { @@ -96,7 +96,7 @@ pub struct SoftmaxAttrsNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "BatchNormAttrs"] #[type_key = "relay.attrs.BatchNormAttrs"] pub struct BatchNormAttrsNode { diff --git a/rust/tvm/src/ir/relay/attrs/transform.rs b/rust/tvm/src/ir/relay/attrs/transform.rs index 863f07617778..c459f96b2d2f 100644 --- a/rust/tvm/src/ir/relay/attrs/transform.rs +++ b/rust/tvm/src/ir/relay/attrs/transform.rs @@ -21,7 +21,7 @@ use crate::ir::attrs::BaseAttrsNode; use tvm_macros::Object; #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "ExpandDimsAttrs"] #[type_key = "relay.attrs.ExpandDimsAttrs"] pub struct ExpandDimsAttrsNode { diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index cc1a76bef7e3..9d2983237acb 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -16,11 +16,6 @@ * specific language governing permissions and limitations * under the License. */ - -pub mod attrs; - -use std::hash::Hash; - use crate::runtime::array::Array; use crate::runtime::{object::*, IsObjectRef, String as TString}; @@ -34,9 +29,12 @@ use tvm_macros::Object; use tvm_rt::NDArray; pub use super::expr::{GlobalVar, GlobalVarNode}; +pub use crate::runtime::DataType; + +pub mod attrs; #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Expr"] #[type_key = "RelayExpr"] pub struct ExprNode { @@ -58,22 +56,8 @@ impl ExprNode { } } -impl Hash for Expr { - fn hash(&self, state: &mut H) { - self.as_ptr().unwrap().ptr.hash(state) - } -} - -impl PartialEq for Expr { - fn eq(&self, other: &Self) -> bool { - self.as_ptr().unwrap().ptr.eq(&other.as_ptr().unwrap().ptr) - } -} - -impl Eq for Expr {} - #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Id"] #[type_key = "relay.Id"] pub struct IdNode { @@ -92,7 +76,7 @@ impl Id { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Constant"] #[type_key = "relay.Constant"] pub struct ConstantNode { @@ -111,7 +95,7 @@ impl Constant { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Tuple"] #[type_key = "relay.Tuple"] pub struct TupleNode { @@ -130,7 +114,7 @@ impl Tuple { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Var"] #[type_key = "relay.Var"] pub struct VarNode { @@ -140,11 +124,11 @@ pub struct VarNode { } impl Var { - pub fn new(name_hint: String, type_annotation: Type, _span: ObjectRef) -> Var { + pub fn new(name_hint: String, type_annotation: Type, _span: Span) -> Var { let node = VarNode { base: ExprNode::base::(), vid: Id::new(name_hint.into()), - type_annotation, + type_annotation: type_annotation, }; Var(Some(ObjectPtr::new(node))) } @@ -153,13 +137,18 @@ impl Var { &self.vid.0.as_ref().unwrap().name_hint } - pub fn to_expr(self) -> Expr { - unsafe { Expr(std::mem::transmute(self.0)) } + pub fn static_tensor(name_hint: String, sh: Vec, dtype: DataType) -> Var { + let sh = Array::from_vec(sh.into_iter().map(Into::into).collect()).unwrap(); + Self::new( + name_hint, + super::ty::TensorType::new(sh, dtype, Span::null()).upcast(), + Span::null(), + ) } } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Call"] #[type_key = "relay.Call"] pub struct CallNode { @@ -190,7 +179,7 @@ impl Call { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Let"] #[type_key = "relay.Let"] pub struct LetNode { @@ -213,7 +202,7 @@ impl Let { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "If"] #[type_key = "relay.If"] pub struct IfNode { @@ -236,7 +225,7 @@ impl If { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "TupleGetItem"] #[type_key = "relay.TupleGetItem"] pub struct TupleGetItemNode { @@ -257,7 +246,7 @@ impl TupleGetItem { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "RefCreate"] #[type_key = "relay.RefCreate"] pub struct RefCreateNode { @@ -276,7 +265,7 @@ impl RefCreate { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "RefRead"] #[type_key = "relay.RefRead"] pub struct RefReadNode { @@ -295,7 +284,7 @@ impl RefRead { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "RefWrite"] #[type_key = "relay.RefWrite"] pub struct RefWriteNode { @@ -316,7 +305,7 @@ impl RefWrite { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Constructor"] #[type_key = "relay.Constructor"] pub struct ConstructorNode { @@ -341,7 +330,7 @@ impl Constructor { // TODO(@jroesch): define the type data #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Pattern"] #[type_key = "relay.Pattern"] pub struct PatternNode { @@ -359,7 +348,7 @@ impl PatternNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "PatternWildcard"] #[type_key = "relay.PatternWildcard"] pub struct PatternWildcardNode { @@ -376,7 +365,7 @@ impl PatternWildcard { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "PatternVar"] #[type_key = "relay.PatternVar"] pub struct PatternVarNode { @@ -395,7 +384,7 @@ impl PatternVar { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "PatternConstructor"] #[type_key = "relay.PatternConstructor"] pub struct PatternConstructorNode { @@ -420,7 +409,7 @@ impl PatternConstructor { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "PatternTuple"] #[type_key = "relay.PatternTuple"] pub struct PatternTupleNode { @@ -439,7 +428,7 @@ impl PatternTuple { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Clause"] #[type_key = "relay.Clause"] pub struct ClauseNode { @@ -460,7 +449,7 @@ impl Clause { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Match"] #[type_key = "relay.Match"] pub struct MatchNode { @@ -483,7 +472,7 @@ impl Match { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Function"] #[type_key = "relay.Function"] pub struct FunctionNode { @@ -510,6 +499,20 @@ impl Function { }; Function(Some(ObjectPtr::new(node))) } + + pub fn simple(params: Vec, body: E) -> Function + where + E: IsObjectRef, + E::Object: AsRef<::Object>, + { + let params = Array::from_vec(params).unwrap(); + Self::new( + params, + body.upcast(), + Type::null(), + Array::from_vec(vec![]).unwrap(), + ) + } } #[cfg(test)] @@ -530,7 +533,7 @@ mod tests { #[test] fn test_global() -> Result<()> { - let gv = GlobalVar::new("main".to_string(), ObjectRef::null()); + let gv = GlobalVar::new("main".to_string(), Span::null()); let text = as_text(gv.clone()); assert!(text.contains("@main")); Ok(()) @@ -538,7 +541,7 @@ mod tests { #[test] fn test_var() -> Result<()> { - let var = Var::new("local".to_string(), Type::null(), ObjectRef::null()); + let var = Var::new("local".to_string(), Type::null(), Span::null()); let text = as_text(var.clone()); assert!(text.contains("%local")); Ok(()) @@ -557,7 +560,7 @@ def @main() -> float32 { ) .unwrap(); let main = module - .lookup(module.get_global_var("main".to_string().into()).unwrap()) + .lookup(module.get_global_var("main").unwrap()) .unwrap(); let func = main.downcast::().unwrap(); let constant = func diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs index 54e16dac62ac..7376f4b74022 100644 --- a/rust/tvm/src/ir/source_map.rs +++ b/rust/tvm/src/ir/source_map.rs @@ -29,7 +29,7 @@ use tvm_macros::Object; /// /// Could represent the source from an ML framework or a source of an IRModule. #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[type_key = "Source"] #[ref_name = "Source"] pub struct SourceNode { @@ -46,7 +46,7 @@ pub struct SourceNode { /// A mapping from a unique source name to source fragments. #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[type_key = "SourceMap"] #[ref_name = "SourceMap"] pub struct SourceMapNode { diff --git a/rust/tvm/src/ir/span.rs b/rust/tvm/src/ir/span.rs index eb6821af69dc..be74745b60ca 100644 --- a/rust/tvm/src/ir/span.rs +++ b/rust/tvm/src/ir/span.rs @@ -23,7 +23,7 @@ use tvm_macros::Object; /// A source file name, contained in a Span. #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[type_key = "SourceName"] #[ref_name = "SourceName"] pub struct SourceNameNode { @@ -33,7 +33,7 @@ pub struct SourceNameNode { /// Span information for diagnostic purposes. #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[type_key = "Span"] #[ref_name = "Span"] pub struct SpanNode { diff --git a/rust/tvm/src/ir/tir.rs b/rust/tvm/src/ir/tir.rs index 22d4e02054e1..ccbe30c95820 100644 --- a/rust/tvm/src/ir/tir.rs +++ b/rust/tvm/src/ir/tir.rs @@ -26,7 +26,7 @@ use tvm_macros::Object; macro_rules! define_node { ($name:ident, $ref:expr, $typekey:expr; $node:ident { $($id:ident : $t:ty),*}) => { #[repr(C)] - #[derive(Object)] + #[derive(Object, Debug)] #[ref_name = $ref] #[type_key = $typekey] pub struct $node { @@ -47,6 +47,20 @@ macro_rules! define_node { // TODO(@jroesch): should move up to expr.rs to mirror TVM. define_node!(IntImm, "IntImm", "IntImm"; IntImmNode { value: i64 }); + +impl From for IntImm { + fn from(i: i32) -> IntImm { + IntImm::new(DataType::int(32, 1), i as i64) + } +} + +impl From for PrimExpr { + fn from(i: i32) -> PrimExpr { + use crate::runtime::IsObjectRef; + IntImm::from(i).upcast() + } +} + define_node!(Var, "Var", "tir.Var"; VarNode { name_hint: TVMString }); diff --git a/rust/tvm/src/ir/ty.rs b/rust/tvm/src/ir/ty.rs index d12f094a63ea..f7c52b51f332 100644 --- a/rust/tvm/src/ir/ty.rs +++ b/rust/tvm/src/ir/ty.rs @@ -17,15 +17,16 @@ * under the License. */ -use super::span::Span; -use crate::runtime::{IsObject, Object, ObjectPtr}; use tvm_macros::Object; use tvm_rt::{array::Array, DataType}; -use super::PrimExpr; +use crate::ir::relay::Constructor; +use crate::ir::span::Span; +use crate::ir::PrimExpr; +use crate::runtime::{string::String as TString, IsObject, Object, ObjectPtr}; #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Type"] #[type_key = "Type"] pub struct TypeNode { @@ -51,7 +52,7 @@ impl TypeNode { * \sa PrimType */ #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "PrimType"] #[type_key = "PrimType"] pub struct PrimTypeNode { @@ -73,7 +74,7 @@ pub struct PrimTypeNode { */ #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "PointerType"] #[type_key = "PointerType"] pub struct PointerTypeNode { @@ -83,6 +84,7 @@ pub struct PointerTypeNode { } /// Possible kinds of type variables. +#[derive(PartialEq, Eq, Debug)] pub enum TypeKind { Type = 0, /// Template variable in shape expression. @@ -92,47 +94,51 @@ pub enum TypeKind { TypeData = 6, } -/* - * \brief Type parameter in functions. - * - * A type variable can be viewed as template parameter in c++ template function. - * - * For example, in the following pesudo code, - * the TypeVar of f is TypeVar("n", kind=kShapeVar). - * This function can take in a Tensor with shape=(3, 3) and - * returns a Tensor with shape=(9,) - * - * \code - * - * template - * f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)] - * - * \endcode - * \sa TypeVar, TypeKind - */ +/// Type parameter in functions. +/// +/// A type variable can be viewed as template parameter in c++ template function. +/// +/// For example, in the following pesudo code, +/// the TypeVar of f is TypeVar("n", kind=kShapeVar). +/// This function can take in a Tensor with shape=(3, 3) and +/// returns a Tensor with shape=(9,) #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "TypeVar"] #[type_key = "TypeVar"] pub struct TypeVarNode { pub base: TypeNode, - pub name_hint: String, + pub name_hint: TString, pub kind: TypeKind, } /// A global type variable that is used for defining new types or type aliases. #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "GlobalTypeVar"] #[type_key = "GlobalTypeVar"] pub struct GlobalTypeVarNode { pub base: TypeNode, - pub name_hint: String, + pub name_hint: TString, pub kind: TypeKind, } +impl GlobalTypeVar { + pub fn new(name_hint: S, kind: TypeKind, span: Span) -> GlobalTypeVar + where + S: Into, + { + let node = GlobalTypeVarNode { + base: TypeNode::base::(span), + name_hint: name_hint.into(), + kind: kind, + }; + ObjectPtr::new(node).into() + } +} + #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "TupleType"] #[type_key = "TupleType"] pub struct TupleTypeNode { @@ -147,7 +153,7 @@ impl TupleType { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "TypeConstraint"] #[type_key = "TypeConstraint"] pub struct TypeConstraintNode { @@ -156,7 +162,7 @@ pub struct TypeConstraintNode { /// The representation of a polymorphic function type. #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "FuncType"] #[type_key = "FuncType"] pub struct FuncTypeNode { @@ -181,7 +187,7 @@ pub struct FuncTypeNode { * TypeVar represents the input to the graph. */ #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "IncompleteType"] #[type_key = "IncompleteType"] pub struct IncompleteTypeNode { @@ -195,7 +201,7 @@ pub struct IncompleteTypeNode { * \sa RelayRefType. */ #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "RefType"] #[type_key = "relay.RefType"] pub struct RelayRefTypeNode { @@ -204,7 +210,7 @@ pub struct RelayRefTypeNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "BaseTensorType"] #[type_key = "relay.BaseTensorType"] pub struct BaseTensorTypeNode { @@ -212,7 +218,7 @@ pub struct BaseTensorTypeNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "TensorType"] #[type_key = "relay.TensorType"] pub struct TensorTypeNode { @@ -240,3 +246,52 @@ impl TensorType { // using TypeRelationFn = tvm::TypeRelationFn; // using TypeReporter = tvm::TypeReporter; // using TypeReporterNode = tvm::TypeReporterNode; + +/* TypeData container node. +\brief Stores all data for an Algebraic Data Type (ADT). + +In particular, it stores the handle (global type var) for an ADT +and the constructors used to build it and is kept in the module. Note +that type parameters are also indicated in the type data: this means that +for any instance of an ADT, the type parameters must be indicated. That is, +an ADT definition is treated as a type-level function, so an ADT handle +must be wrapped in a TypeCall node that instantiates the type-level arguments. +The kind checker enforces this. */ +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "TypeData"] +#[type_key = "relay.TypeData"] +pub struct TypeDataNode { + /// The header is simply the name of the ADT. + /// We adopt nominal typing for ADT definitions; + /// that is, differently-named ADT definitions with same constructors + /// have different types. + pub base: TypeNode, + pub type_name: GlobalTypeVar, + /// The type variables (to allow for polymorphism). + pub type_vars: Array, + /// The constructors. + pub constructors: Array, +} + +impl TypeData { + pub fn new( + type_name: GlobalTypeVar, + type_vars: TypeVars, + constructors: Ctors, + span: Span, + ) -> TypeData + where + TypeVars: IntoIterator, + Ctors: IntoIterator, + { + use std::iter::FromIterator; + let type_data = TypeDataNode { + base: TypeNode::base::(span), + type_name, + type_vars: Array::from_iter(type_vars), + constructors: Array::from_iter(constructors), + }; + TypeData(Some(ObjectPtr::new(type_data))) + } +} diff --git a/rust/tvm/src/transform.rs b/rust/tvm/src/transform.rs index c5a65c417c93..b49633777b65 100644 --- a/rust/tvm/src/transform.rs +++ b/rust/tvm/src/transform.rs @@ -33,7 +33,7 @@ pub type IRModule = ObjectRef; pub type PassContext = ObjectRef; #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "PassInfo"] #[type_key = "transform.PassInfo"] pub struct PassInfoNode { diff --git a/src/ir/module.cc b/src/ir/module.cc index b011f2d2f664..7990b281fb04 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -439,6 +439,9 @@ TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars") TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar") .set_body_method(&IRModuleNode::ContainGlobalVar); +TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalTypeVar") + .set_body_method(&IRModuleNode::ContainGlobalTypeVar); + TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar") .set_body_method(&IRModuleNode::GetGlobalTypeVar); diff --git a/tests/scripts/task_rust.sh b/tests/scripts/task_rust.sh index d60999c3f3d0..2c87cceec8bb 100755 --- a/tests/scripts/task_rust.sh +++ b/tests/scripts/task_rust.sh @@ -74,7 +74,7 @@ cd tests/test_tvm_dso cargo run cd - -# # run wasm32 test +# run wasm32 test # cd tests/test_wasm32 # cargo build # wasmtime $RUST_DIR/target/wasm32-wasi/debug/test-wasm32.wasm