From 57c72f5c2f45f1f97746d0ed126526304c6f0a3d Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 21 Oct 2020 14:09:37 -0700 Subject: [PATCH 01/12] WIP --- rust/tvm-macros/src/external.rs | 5 ++- rust/tvm-macros/src/lib.rs | 1 + rust/tvm/src/ir/module.rs | 67 +++++++++++---------------------- 3 files changed, 26 insertions(+), 47 deletions(-) diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs index 802d7aeb6779..de8ada3acdda 100644 --- a/rust/tvm-macros/src/external.rs +++ b/rust/tvm-macros/src/external.rs @@ -17,6 +17,7 @@ * under the License. */ use proc_macro2::Span; +use proc_macro_error::abort; use quote::quote; use syn::parse::{Parse, ParseStream, Result}; @@ -109,7 +110,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { .iter() .map(|ty_param| match ty_param { syn::GenericParam::Type(param) => param.clone(), - _ => panic!(), + _ => abort! { ty_param, + "Only supports type parameters." + } }) .collect(); diff --git a/rust/tvm-macros/src/lib.rs b/rust/tvm-macros/src/lib.rs index 603e1ceaafcc..ab75c926b279 100644 --- a/rust/tvm-macros/src/lib.rs +++ b/rust/tvm-macros/src/lib.rs @@ -35,6 +35,7 @@ pub fn macro_impl(input: TokenStream) -> TokenStream { TokenStream::from(object::macro_impl(input)) } +#[proc_macro_error] #[proc_macro] pub fn external(input: TokenStream) -> TokenStream { external::macro_impl(input) diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index 190b477b98f2..b5538d0aa196 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -62,6 +62,8 @@ external! { #[name("parser.ParseExpr")] fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule; // Module methods + #[name("ir.Module_Add")] + fn module_add_def(module: IRModule, type_name: GlobalVar, expr: relay::Expr, update: bool) -> (); #[name("ir.Module_AddDef")] fn module_add_def(module: IRModule, type_name: GlobalTypeVar, type_data: TypeData, update: bool) -> (); #[name("ir.Module_GetGlobalVar")] @@ -72,55 +74,28 @@ external! { fn module_lookup(module: IRModule, var: GlobalVar) -> BaseFunc; #[name("ir.Module_Lookup_str")] fn module_lookup_str(module: IRModule, name: TVMString) -> BaseFunc; + #[name("ir.Module_GetGlobalTypeVars")] + fn module_get_global_type_vars() -> Array; + #[name("ir.Module_ContainGlobalVar")] + fn module_get_global_var(name: TVMString) -> bool; + #[name("ir.Module_ContainGlobalTypeVar")] + fn module_get_global_type_var(name: TVMString) -> bool; + #[name("ir.Module_LookupDef")] + fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeDef; + #[name("ir.Module_LookupDef_str")] + fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> TypeDef; + #[name("ir.Module_LookupTag")] + fn module_lookup_tag(module: IRModule, tag: i32) -> Constructor; + #[name("ir.Module_FromExpr")] + fn module_from_expr(expr: relay::Expr, funcs: Map, types: Map) -> IRModule; + #[name("ir.Module_Import")] + fn module_import(module: IRModule, path: TVMString); + #[name("ir.Module_ImportFromStd")] + fn module_import_from_std(module: IRModule, path: TVMString); } -// TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars") -// .set_body_method(&IRModuleNode::GetGlobalTypeVars); +// Note: we don't expose update here as update is going to be removed. -// TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar") -// .set_body_method(&IRModuleNode::ContainGlobalVar); - -// TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar") -// .set_body_method(&IRModuleNode::GetGlobalTypeVar); - -// TVM_REGISTER_GLOBAL("ir.Module_LookupDef").set_body_typed([](IRModule mod, GlobalTypeVar var) { -// return mod->LookupTypeDef(var); -// }); - -// TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str").set_body_typed([](IRModule mod, String var) { -// return mod->LookupTypeDef(var); -// }); - -// TVM_REGISTER_GLOBAL("ir.Module_LookupTag").set_body_typed([](IRModule mod, int32_t tag) { -// return mod->LookupTag(tag); -// }); - -// TVM_REGISTER_GLOBAL("ir.Module_FromExpr") -// .set_body_typed([](RelayExpr e, tvm::Map funcs, -// tvm::Map type_defs) { -// return IRModule::FromExpr(e, funcs, type_defs); -// }); - -// TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule from) { -// mod->Update(from); -// }); - -// TVM_REGISTER_GLOBAL("ir.Module_UpdateFunction") -// .set_body_typed([](IRModule mod, GlobalVar gv, BaseFunc func) { mod->Update(gv, func); }); - -// TVM_REGISTER_GLOBAL("ir.Module_Import").set_body_typed([](IRModule mod, String path) { -// mod->Import(path); -// }); - -// TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd").set_body_typed([](IRModule mod, String path) { -// mod->ImportFromStd(path); -// }); - -// TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -// .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { -// auto* node = static_cast(ref.get()); -// p->stream << "IRModuleNode( " << node->functions << ")"; -// }); impl IRModule { pub fn parse(file_name: N, source: S) -> Result From 7a5869e1fa35a0c62e91ab73aaa7b600d208566d Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 22 Oct 2020 11:48:34 -0700 Subject: [PATCH 02/12] WIP --- rust/tvm-macros/Cargo.toml | 2 +- rust/tvm-macros/src/external.rs | 43 +++++++++++++++++++++++++++------ rust/tvm-macros/src/lib.rs | 1 + rust/tvm-rt/src/object/mod.rs | 2 +- rust/tvm/src/ir/module.rs | 20 ++++++++++----- 5 files changed, 52 insertions(+), 16 deletions(-) diff --git a/rust/tvm-macros/Cargo.toml b/rust/tvm-macros/Cargo.toml index 63b84727c525..8e97d3b670d8 100644 --- a/rust/tvm-macros/Cargo.toml +++ b/rust/tvm-macros/Cargo.toml @@ -33,5 +33,5 @@ proc-macro = true goblin = "^0.2" proc-macro2 = "^1.0" quote = "^1.0" -syn = { version = "1.0.17", features = ["full", "extra-traits"] } +syn = { version = "^1.0", features = ["full", "parsing", "extra-traits"] } proc-macro-error = "^1.0" diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs index de8ada3acdda..44a242cf5e88 100644 --- a/rust/tvm-macros/src/external.rs +++ b/rust/tvm-macros/src/external.rs @@ -21,9 +21,28 @@ use proc_macro_error::abort; use quote::quote; use syn::parse::{Parse, ParseStream, Result}; -use syn::{FnArg, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type}; +use syn::{Token, FnArg, Signature, Attribute, token::Semi, Visibility, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type}; + +struct ExternalItem { + attrs: Vec, + visibility: Visibility, + sig: Signature, +} + +impl Parse for ExternalItem { + fn parse(input: ParseStream) -> Result { + let item = ExternalItem { + attrs: input.call(Attribute::parse_outer)?, + visibility: input.parse()?, + sig: input.parse()?, + }; + let _semi: Semi = input.parse()?; + Ok(item) + } +} struct External { + visibility: Visibility, tvm_name: String, ident: Ident, generics: Generics, @@ -33,7 +52,8 @@ struct External { impl Parse for External { fn parse(input: ParseStream) -> Result { - let method: TraitItemMethod = input.parse()?; + let method: ExternalItem = input.parse()?; + let visibility = method.visibility; assert_eq!(method.attrs.len(), 1); let sig = method.sig; let tvm_name = method.attrs[0].parse_meta()?; @@ -48,8 +68,7 @@ impl Parse for External { } _ => panic!(), }; - assert_eq!(method.default, None); - assert!(method.semi_token != None); + let ident = sig.ident; let generics = sig.generics; let inputs = sig @@ -61,6 +80,7 @@ impl Parse for External { let ret_type = sig.output; Ok(External { + visibility, tvm_name, ident, generics, @@ -99,6 +119,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let mut items = Vec::new(); for external in &ext_input.externs { + let visibility = &external.visibility; let name = &external.ident; let global_name = format!("global_{}", external.ident); let global_name = Ident::new(&global_name, Span::call_site()); @@ -127,15 +148,21 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let ty: Type = *pat_type.ty.clone(); (ident, ty) } - _ => panic!(), + _ => abort! { pat_type, + "Only supports type parameters." + } }, - _ => panic!(), + pat => abort! { + pat, "invalid pattern type for function"; + + note = "{:?} is not allowed here", pat; + } }) .unzip(); let ret_type = match &external.ret_type { ReturnType::Type(_, rtype) => *rtype.clone(), - _ => panic!(), + ReturnType::Default => syn::parse_str::("()").unwrap(), }; let global = quote! { @@ -150,7 +177,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { items.push(global); let wrapper = quote! { - pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> { + #visibility fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> { let func_ref: #tvm_rt_crate::Function = #global_name.clone(); let func_ref: Box #result_type<#ret_type>> = func_ref.into(); let res: #ret_type = func_ref(#(#args),*)?; diff --git a/rust/tvm-macros/src/lib.rs b/rust/tvm-macros/src/lib.rs index ab75c926b279..32f28394173e 100644 --- a/rust/tvm-macros/src/lib.rs +++ b/rust/tvm-macros/src/lib.rs @@ -18,6 +18,7 @@ */ use proc_macro::TokenStream; +use proc_macro_error::proc_macro_error; mod external; mod import_module; diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs index 46e034232a63..e48c0173e4e4 100644 --- a/rust/tvm-rt/src/object/mod.rs +++ b/rust/tvm-rt/src/object/mod.rs @@ -88,7 +88,7 @@ pub trait IsObjectRef: external! { #[name("ir.DebugPrint")] - fn debug_print(object: ObjectRef) -> CString; + pub fn debug_print(object: ObjectRef) -> CString; #[name("node.StructuralHash")] fn structural_hash(object: ObjectRef, map_free_vars: bool) -> ObjectRef; #[name("node.StructuralEqual")] diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index b5538d0aa196..d589c5338e30 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -21,6 +21,9 @@ use std::path::Path; use thiserror::Error; use tvm_macros::Object; +use std::io::Result as IOResult; +use std::path::Path; + use crate::runtime::array::Array; use crate::runtime::function::Result; use crate::runtime::map::Map; @@ -30,10 +33,9 @@ use crate::runtime::{external, Object, ObjectRef}; use super::expr::GlobalVar; use super::function::BaseFunc; use super::source_map::SourceMap; +use super::{ty::GlobalTypeVar, relay}; -// TODO(@jroesch): define type type TypeData = ObjectRef; -type GlobalTypeVar = ObjectRef; #[derive(Error, Debug)] pub enum Error { @@ -63,7 +65,7 @@ external! { fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule; // Module methods #[name("ir.Module_Add")] - fn module_add_def(module: IRModule, type_name: GlobalVar, expr: relay::Expr, update: bool) -> (); + fn module_add(module: IRModule, type_name: GlobalVar, expr: relay::Expr, update: bool) -> (); #[name("ir.Module_AddDef")] fn module_add_def(module: IRModule, type_name: GlobalTypeVar, type_data: TypeData, update: bool) -> (); #[name("ir.Module_GetGlobalVar")] @@ -77,15 +79,15 @@ external! { #[name("ir.Module_GetGlobalTypeVars")] fn module_get_global_type_vars() -> Array; #[name("ir.Module_ContainGlobalVar")] - fn module_get_global_var(name: TVMString) -> bool; + fn module_contains_global_var(name: TVMString) -> bool; #[name("ir.Module_ContainGlobalTypeVar")] - fn module_get_global_type_var(name: TVMString) -> bool; + fn module_contains_global_type_var(name: TVMString) -> bool; #[name("ir.Module_LookupDef")] fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeDef; #[name("ir.Module_LookupDef_str")] fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> TypeDef; #[name("ir.Module_LookupTag")] - fn module_lookup_tag(module: IRModule, tag: i32) -> Constructor; + fn module_lookup_tag(module: IRModule, tag: i32) -> relay::Constructor; #[name("ir.Module_FromExpr")] fn module_from_expr(expr: relay::Expr, funcs: Map, types: Map) -> IRModule; #[name("ir.Module_Import")] @@ -144,3 +146,9 @@ impl IRModule { module_lookup_str(self.clone(), name.into()) } } + +#[cfg(test)] +mod tests { + // #[test] + // fn +} From b2d0095bedc2684019816c94060175c45f880f13 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 22 Oct 2020 22:18:58 -0700 Subject: [PATCH 03/12] WIP --- rust/tvm-macros/src/external.rs | 2 +- rust/tvm-macros/src/lib.rs | 3 +- rust/tvm-macros/src/object.rs | 23 ++++ rust/tvm-rt/src/object/mod.rs | 9 +- rust/tvm-rt/src/object/object_ptr.rs | 16 +++ rust/tvm-rt/src/string.rs | 1 + rust/tvm-rt/src/value.rs | 1 - rust/tvm-sys/src/datatype.rs | 4 + rust/tvm/src/ir/module.rs | 159 +++++++++++++++++++++++++-- rust/tvm/src/ir/relay/mod.rs | 36 +++--- rust/tvm/src/ir/tir.rs | 14 +++ 11 files changed, 226 insertions(+), 42 deletions(-) diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs index 44a242cf5e88..51a389b2f2f2 100644 --- a/rust/tvm-macros/src/external.rs +++ b/rust/tvm-macros/src/external.rs @@ -21,7 +21,7 @@ use proc_macro_error::abort; use quote::quote; use syn::parse::{Parse, ParseStream, Result}; -use syn::{Token, FnArg, Signature, Attribute, token::Semi, Visibility, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type}; +use syn::{FnArg, Signature, Attribute, token::Semi, Visibility, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, Type}; struct ExternalItem { attrs: Vec, diff --git a/rust/tvm-macros/src/lib.rs b/rust/tvm-macros/src/lib.rs index 32f28394173e..e563a57f149e 100644 --- a/rust/tvm-macros/src/lib.rs +++ b/rust/tvm-macros/src/lib.rs @@ -30,7 +30,8 @@ pub fn import_module(input: TokenStream) -> TokenStream { import_module::macro_impl(input) } -#[proc_macro_derive(Object, attributes(base, ref_name, type_key))] +#[proc_macro_error] +#[proc_macro_derive(Object, attributes(base, ref_name, type_key, no_derive))] pub fn macro_impl(input: TokenStream) -> TokenStream { // let input = proc_macro2::TokenStream::from(input); TokenStream::from(object::macro_impl(input)) diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs index ff72d6a649be..7e6a9343da89 100644 --- a/rust/tvm-macros/src/object.rs +++ b/rust/tvm-macros/src/object.rs @@ -36,6 +36,8 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { .map(attr_to_str) .expect("Failed to get type_key"); + let derive = get_attr(&derive_input, "no_derive").map(|_| false).unwrap_or(true); + let ref_id = get_attr(&derive_input, "ref_name") .map(|a| Ident::new(attr_to_str(a).value().as_str(), Span::call_site())) .unwrap_or_else(|| { @@ -185,5 +187,26 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { expanded.extend(base_tokens); + if derive { + let derives = quote! { + impl std::hash::Hash for #ref_id { + fn hash(&self, state: &mut H) { + self.0.hash(state) + } + } + + impl std::cmp::PartialEq for #ref_id { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } + } + + impl std::cmp::Eq for #ref_id {} + }; + + + expanded.extend(derives); + } + TokenStream::from(expanded) } diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs index e48c0173e4e4..7e6107d7cbd0 100644 --- a/rust/tvm-rt/src/object/mod.rs +++ b/rust/tvm-rt/src/object/mod.rs @@ -90,12 +90,7 @@ external! { #[name("ir.DebugPrint")] pub fn debug_print(object: ObjectRef) -> CString; #[name("node.StructuralHash")] - fn structural_hash(object: ObjectRef, map_free_vars: bool) -> ObjectRef; + fn structural_hash(object: ObjectRef, map_free_vars: bool) -> i64; #[name("node.StructuralEqual")] - fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool) -> ObjectRef; + fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool) -> bool; } - -// external! { -// #[name("ir.TextPrinter")] -// fn as_text(object: ObjectRef) -> CString; -// } diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 8d535368c352..0f61bd228988 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -342,6 +342,22 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { } } +impl std::hash::Hash for ObjectPtr { + fn hash(&self, state: &mut H) { + state.write_i64(super::structural_hash(ObjectRef(Some(self.clone().upcast())), false).unwrap()) + } +} + +impl PartialEq for ObjectPtr { + fn eq(&self, other: &Self) -> bool { + let lhs = ObjectRef(Some(self.clone().upcast())); + let rhs = ObjectRef(Some(other.clone().upcast())); + super::structural_equal(lhs, rhs, false, false).unwrap() + } +} + +impl Eq for ObjectPtr {} + #[cfg(test)] mod tests { use super::{Object, ObjectPtr}; diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs index 3cd33a226d44..a650cc994d98 100644 --- a/rust/tvm-rt/src/string.rs +++ b/rust/tvm-rt/src/string.rs @@ -28,6 +28,7 @@ use tvm_macros::Object; #[derive(Object)] #[ref_name = "String"] #[type_key = "runtime.String"] +#[no_derive] pub struct StringObj { base: Object, data: *const u8, diff --git a/rust/tvm-rt/src/value.rs b/rust/tvm-rt/src/value.rs index c49944dc7e33..b8cd190176c4 100644 --- a/rust/tvm-rt/src/value.rs +++ b/rust/tvm-rt/src/value.rs @@ -22,7 +22,6 @@ //! `RetValue` is the owned version of `TVMPODValue`. use std::convert::TryFrom; -// use std::ffi::c_void; use crate::{ArgValue, Module, RetValue}; use tvm_sys::{errors::ValueDowncastError, ffi::TVMModuleHandle, try_downcast}; diff --git a/rust/tvm-sys/src/datatype.rs b/rust/tvm-sys/src/datatype.rs index 8050d932e5c1..5f7e0c3a3b60 100644 --- a/rust/tvm-sys/src/datatype.rs +++ b/rust/tvm-sys/src/datatype.rs @@ -83,6 +83,10 @@ impl DataType { DataType::new(DL_FLOAT_CODE, bits, lanes) } + pub const fn float32() -> DataType { + Self::float(32, 1) + } + pub const fn uint(bits: u8, lanes: u16) -> DataType { DataType::new(DL_UINT_CODE, bits, lanes) } diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index d589c5338e30..04115ec57bd6 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -22,6 +22,7 @@ use thiserror::Error; use tvm_macros::Object; use std::io::Result as IOResult; +use std::iter::FromIterator; use std::path::Path; use crate::runtime::array::Array; @@ -35,6 +36,12 @@ use super::function::BaseFunc; use super::source_map::SourceMap; use super::{ty::GlobalTypeVar, relay}; +<<<<<<< HEAD +======= +use tvm_macros::Object; + +// TODO(@jroesch): define type +>>>>>>> WIP type TypeData = ObjectRef; #[derive(Error, Debug)] @@ -63,9 +70,11 @@ external! { fn parse_module(file_name: TVMString, source: TVMString) -> IRModule; #[name("parser.ParseExpr")] fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule; + #[name("ir.IRModule")] + fn module_new(funcs: Map, types: Map) -> IRModule; // Module methods #[name("ir.Module_Add")] - fn module_add(module: IRModule, type_name: GlobalVar, expr: relay::Expr, update: bool) -> (); + fn module_add(module: IRModule, type_name: GlobalVar, expr: BaseFunc, update: bool) -> IRModule; #[name("ir.Module_AddDef")] fn module_add_def(module: IRModule, type_name: GlobalTypeVar, type_data: TypeData, update: bool) -> (); #[name("ir.Module_GetGlobalVar")] @@ -77,15 +86,15 @@ external! { #[name("ir.Module_Lookup_str")] fn module_lookup_str(module: IRModule, name: TVMString) -> BaseFunc; #[name("ir.Module_GetGlobalTypeVars")] - fn module_get_global_type_vars() -> Array; + fn module_get_global_type_vars(module: IRModule) -> Array; #[name("ir.Module_ContainGlobalVar")] - fn module_contains_global_var(name: TVMString) -> bool; + fn module_contains_global_var(module: IRModule, name: TVMString) -> bool; #[name("ir.Module_ContainGlobalTypeVar")] - fn module_contains_global_type_var(name: TVMString) -> bool; + fn module_contains_global_type_var(module: IRModule, name: TVMString) -> bool; #[name("ir.Module_LookupDef")] - fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeDef; + fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeData; #[name("ir.Module_LookupDef_str")] - fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> TypeDef; + fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> TypeData; #[name("ir.Module_LookupTag")] fn module_lookup_tag(module: IRModule, tag: i32) -> relay::Constructor; #[name("ir.Module_FromExpr")] @@ -98,9 +107,17 @@ external! { // Note: we don't expose update here as update is going to be removed. - impl IRModule { +<<<<<<< HEAD pub fn parse(file_name: N, source: S) -> Result +======= + pub fn new(funcs: F, types: T) -> Result + where F: IntoIterator, T: IntoIterator { + module_new(Map::from_iter(funcs), Map::from_iter(types)) + } + + pub fn parse(file_name: N, source: S) -> IRModule +>>>>>>> WIP where N: Into, S: Into, @@ -118,6 +135,13 @@ impl IRModule { Ok(module) } + pub fn add( + &mut self, + var: GlobalVar, + func: BaseFunc) -> Result { + module_add(self.clone(), var, func, true) + } + pub fn add_def( &mut self, type_name: GlobalTypeVar, @@ -145,10 +169,127 @@ impl IRModule { { module_lookup_str(self.clone(), name.into()) } + + pub fn get_global_type_vars(&self) -> Result> { + module_get_global_type_vars(self.clone()) + } + + pub fn contains_global_var>(&self, name: S) -> Result { + module_contains_global_var(self.clone(), name.into()) + } + + pub fn contains_global_type_var>(&self, name: S) -> Result { + module_contains_global_type_var(self.clone(), name.into()) + } + + pub fn lookup_def(&self, global: GlobalTypeVar) -> Result { + module_lookup_def(self.clone(), global) + } + + pub fn lookup_def_str(&self, global: GlobalTypeVar) -> Result { + module_lookup_def_str(self.clone(), global) + } + + pub fn lookup_tag(&self, tag: i32) -> Result { + module_lookup_tag(self.clone(), tag) + } + + pub fn from_expr(expr: relay::Expr, funcs: Map, types: Map) -> Result { + module_from_expr(expr, funcs, types) + } + + pub fn import>(&mut self, path: S) -> Result<()> { + module_import(self.clone(), path.into()) + } + + pub fn import_from_std>(&mut self, path: S) -> Result<()> { + module_import_from_std(self.clone(), path.into()) + } } #[cfg(test)] mod tests { - // #[test] - // fn + use std::collections::HashMap; + use super::relay::*; + use super::*; + use super::super::span::Span; + use tvm_rt::IsObjectRef; + + #[test] + fn test_module_add() -> anyhow::Result<()> { + let funcs = HashMap::::new(); + let types = HashMap::::new(); + let mut module = IRModule::new(funcs, types)?; + let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32()); + let params = Array::from_vec(vec![x.clone()])?; + let func = relay::Function::simple(params, x.upcast()).upcast(); + let module = module.add(GlobalVar::new("foo".into(), Span::null()), func)?; + // let lfunc = module.lookup_str("foo")?; + // let lfunc = lfunc.downcast::()?; + // assert_eq!(lfunc.params.len(), 1); + Ok(()) + } + + #[test] + fn test_module_add_def() { + + } + + #[test] + fn test_get_global_var() { + + } + + #[test] + fn test_get_global_vars() { + + } + + #[test] + fn test_lookup() { + + } + + + // pub fn get_global_type_vars(&self) -> Result> { + // module_get_global_type_vars(self.clone()) + // } + + // pub fn contains_global_var>(&self, name: S) -> Result { + // module_contains_global_var(self.clone(), name.into()) + // } + + // pub fn contains_global_type_var>(&self, name: S) -> Result { + // module_contains_global_type_var(self.clone(), name.into()) + // } + + #[test] + fn test_lookup_def() { + + } + // pub fn lookup_def(&self, global: GlobalTypeVar) -> Result { + // module_lookup_def(self.clone(), global) + // } + + // pub fn lookup_def_str(&self, global: GlobalTypeVar) -> Result { + // module_lookup_def_str(self.clone(), global) + // } + + // pub fn lookup_tag(&self, tag: i32) -> Result { + // module_lookup_tag(self.clone(), tag) + // } + + // pub fn from_expr(expr: relay::Expr, funcs: Map, types: Map) -> Result { + // module_from_expr(expr, funcs, types) + // } + + + // pub fn import>(&mut self, path: S) -> Result<()> { + // module_import(self.clone(), path.into()) + // } + + + // pub fn import_from_std>(&mut self, path: S) -> Result<()> { + // module_import_from_std(self.clone(), path.into()) + // } } diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index cc1a76bef7e3..c4d46acee8a7 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -16,11 +16,6 @@ * specific language governing permissions and limitations * under the License. */ - -pub mod attrs; - -use std::hash::Hash; - use crate::runtime::array::Array; use crate::runtime::{object::*, IsObjectRef, String as TString}; @@ -29,11 +24,15 @@ use super::expr::BaseExprNode; use super::function::BaseFuncNode; use super::span::Span; use super::ty::{Type, TypeNode}; +use super::span::Span; use tvm_macros::Object; use tvm_rt::NDArray; pub use super::expr::{GlobalVar, GlobalVarNode}; +pub use crate::runtime::DataType; + +pub mod attrs; #[repr(C)] #[derive(Object)] @@ -58,20 +57,6 @@ impl ExprNode { } } -impl Hash for Expr { - fn hash(&self, state: &mut H) { - self.as_ptr().unwrap().ptr.hash(state) - } -} - -impl PartialEq for Expr { - fn eq(&self, other: &Self) -> bool { - self.as_ptr().unwrap().ptr.eq(&other.as_ptr().unwrap().ptr) - } -} - -impl Eq for Expr {} - #[repr(C)] #[derive(Object)] #[ref_name = "Id"] @@ -140,11 +125,11 @@ pub struct VarNode { } impl Var { - pub fn new(name_hint: String, type_annotation: Type, _span: ObjectRef) -> Var { + pub fn new(name_hint: String, type_annotation: Type, _span: Span) -> Var { let node = VarNode { base: ExprNode::base::(), vid: Id::new(name_hint.into()), - type_annotation, + type_annotation: type_annotation, }; Var(Some(ObjectPtr::new(node))) } @@ -153,8 +138,9 @@ impl Var { &self.vid.0.as_ref().unwrap().name_hint } - pub fn to_expr(self) -> Expr { - unsafe { Expr(std::mem::transmute(self.0)) } + pub fn static_tensor(name_hint: String, sh: Vec, dtype: DataType) -> Var { + let sh = Array::from_vec(sh.into_iter().map(Into::into).collect()).unwrap(); + Self::new(name_hint, super::ty::TensorType::new(sh, dtype, Span::null()).upcast(), Span::null()) } } @@ -510,6 +496,10 @@ impl Function { }; Function(Some(ObjectPtr::new(node))) } + + pub fn simple(params: Array, body: Expr) -> Function { + Self::new(params, body, Type::null(), Array::from_vec(vec![]).unwrap()) + } } #[cfg(test)] diff --git a/rust/tvm/src/ir/tir.rs b/rust/tvm/src/ir/tir.rs index 22d4e02054e1..f07e85486626 100644 --- a/rust/tvm/src/ir/tir.rs +++ b/rust/tvm/src/ir/tir.rs @@ -47,6 +47,20 @@ macro_rules! define_node { // TODO(@jroesch): should move up to expr.rs to mirror TVM. define_node!(IntImm, "IntImm", "IntImm"; IntImmNode { value: i64 }); + +impl From for IntImm { + fn from(i: i32) -> IntImm { + IntImm::new(DataType::int(32, 1), i as i64) + } +} + +impl From for PrimExpr { + fn from(i: i32) -> PrimExpr { + use crate::runtime::IsObjectRef; + IntImm::from(i).upcast() + } +} + define_node!(Var, "Var", "tir.Var"; VarNode { name_hint: TVMString }); From b78453474d9a0574db669cece5d35719e173320c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 1 Nov 2020 17:07:05 -0800 Subject: [PATCH 04/12] WIP --- rust/tvm-graph-rt/tests/test_wasm32/Cargo.toml | 2 +- rust/tvm/src/ir/module.rs | 11 +---------- rust/tvm/src/ir/relay/mod.rs | 1 - tests/scripts/task_rust.sh | 10 +++++----- 4 files changed, 7 insertions(+), 17 deletions(-) diff --git a/rust/tvm-graph-rt/tests/test_wasm32/Cargo.toml b/rust/tvm-graph-rt/tests/test_wasm32/Cargo.toml index aed467f1235d..02e77d106f28 100644 --- a/rust/tvm-graph-rt/tests/test_wasm32/Cargo.toml +++ b/rust/tvm-graph-rt/tests/test_wasm32/Cargo.toml @@ -23,7 +23,7 @@ authors = ["TVM Contributors"] edition = "2018" [dependencies] -ndarray="0.12" +ndarray = "0.12" tvm-graph-rt = { path = "../../" } [build-dependencies] diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index 04115ec57bd6..21119045edd4 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -36,12 +36,7 @@ use super::function::BaseFunc; use super::source_map::SourceMap; use super::{ty::GlobalTypeVar, relay}; -<<<<<<< HEAD -======= -use tvm_macros::Object; - // TODO(@jroesch): define type ->>>>>>> WIP type TypeData = ObjectRef; #[derive(Error, Debug)] @@ -108,16 +103,12 @@ external! { // Note: we don't expose update here as update is going to be removed. impl IRModule { -<<<<<<< HEAD - pub fn parse(file_name: N, source: S) -> Result -======= pub fn new(funcs: F, types: T) -> Result where F: IntoIterator, T: IntoIterator { module_new(Map::from_iter(funcs), Map::from_iter(types)) } - pub fn parse(file_name: N, source: S) -> IRModule ->>>>>>> WIP + pub fn parse(file_name: N, source: S) -> Result where N: Into, S: Into, diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index c4d46acee8a7..f0fca3d5bb67 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -24,7 +24,6 @@ use super::expr::BaseExprNode; use super::function::BaseFuncNode; use super::span::Span; use super::ty::{Type, TypeNode}; -use super::span::Span; use tvm_macros::Object; use tvm_rt::NDArray; diff --git a/tests/scripts/task_rust.sh b/tests/scripts/task_rust.sh index d60999c3f3d0..6ed4df8967b3 100755 --- a/tests/scripts/task_rust.sh +++ b/tests/scripts/task_rust.sh @@ -74,11 +74,11 @@ cd tests/test_tvm_dso cargo run cd - -# # run wasm32 test -# cd tests/test_wasm32 -# cargo build -# wasmtime $RUST_DIR/target/wasm32-wasi/debug/test-wasm32.wasm -# cd - +# run wasm32 test +cd tests/test_wasm32 +cargo build +wasmtime $RUST_DIR/target/wasm32-wasi/debug/test-wasm32.wasm +cd - # run nn graph test cd tests/test_nn From a39a1bf42ffaa181cd8b8817ff067bd28b785343 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 1 Nov 2020 18:12:46 -0800 Subject: [PATCH 05/12] Disable WASM and fix rebase --- rust/Cargo.toml | 1 - rust/tvm/src/ir/module.rs | 6 ++---- tests/scripts/task_rust.sh | 8 ++++---- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 7c092d860b50..e75150859f90 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -27,7 +27,6 @@ members = [ "tvm-graph-rt", "tvm-graph-rt/tests/test_tvm_basic", "tvm-graph-rt/tests/test_tvm_dso", - "tvm-graph-rt/tests/test_wasm32", "tvm-graph-rt/tests/test_nn", "compiler-ext", ] diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index 21119045edd4..ff46a9f71199 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -16,15 +16,13 @@ * specific language governing permissions and limitations * under the License. */ + +use std::iter::FromIterator; use std::path::Path; use thiserror::Error; use tvm_macros::Object; -use std::io::Result as IOResult; -use std::iter::FromIterator; -use std::path::Path; - use crate::runtime::array::Array; use crate::runtime::function::Result; use crate::runtime::map::Map; diff --git a/tests/scripts/task_rust.sh b/tests/scripts/task_rust.sh index 6ed4df8967b3..2c87cceec8bb 100755 --- a/tests/scripts/task_rust.sh +++ b/tests/scripts/task_rust.sh @@ -75,10 +75,10 @@ cargo run cd - # run wasm32 test -cd tests/test_wasm32 -cargo build -wasmtime $RUST_DIR/target/wasm32-wasi/debug/test-wasm32.wasm -cd - +# cd tests/test_wasm32 +# cargo build +# wasmtime $RUST_DIR/target/wasm32-wasi/debug/test-wasm32.wasm +# cd - # run nn graph test cd tests/test_nn From 3e8cd8e59a150beb155d4841131739ca88d3d054 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 1 Nov 2020 18:31:33 -0800 Subject: [PATCH 06/12] Work on finishing tests --- rust/tvm/src/ir/expr.rs | 9 +++++--- rust/tvm/src/ir/module.rs | 40 ++++++++++++++++-------------------- rust/tvm/src/ir/relay/mod.rs | 4 ++-- rust/tvm/src/ir/ty.rs | 30 +++++++++++++++++++++++++++ 4 files changed, 56 insertions(+), 27 deletions(-) diff --git a/rust/tvm/src/ir/expr.rs b/rust/tvm/src/ir/expr.rs index f74522d91c70..4ce3d790dcae 100644 --- a/rust/tvm/src/ir/expr.rs +++ b/rust/tvm/src/ir/expr.rs @@ -17,12 +17,15 @@ * under the License. */ -use super::relay; +use tvm_macros::Object; + use crate::runtime::String as TString; use crate::runtime::{self, external, IsObject, IsObjectRef, Object, ObjectPtr, ObjectRef}; use crate::DataType; -use tvm_macros::Object; +use super::relay; +use super::span::Span; + #[repr(C)] #[derive(Object)] @@ -68,7 +71,7 @@ pub struct GlobalVarNode { } impl GlobalVar { - pub fn new(name_hint: String, _span: ObjectRef) -> GlobalVar { + pub fn new(name_hint: String, _span: Span) -> GlobalVar { let node = GlobalVarNode { base: relay::ExprNode::base::(), name_hint: name_hint.into(), diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index ff46a9f71199..64679e04b3be 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -213,15 +213,15 @@ mod tests { let params = Array::from_vec(vec![x.clone()])?; let func = relay::Function::simple(params, x.upcast()).upcast(); let module = module.add(GlobalVar::new("foo".into(), Span::null()), func)?; - // let lfunc = module.lookup_str("foo")?; - // let lfunc = lfunc.downcast::()?; - // assert_eq!(lfunc.params.len(), 1); + let lfunc = module.lookup_str("foo")?; + let lfunc = lfunc.downcast::()?; + assert_eq!(lfunc.params.len(), 1); Ok(()) } #[test] fn test_module_add_def() { - + todo!("this is blocked on having ability to define ADTs") } #[test] @@ -235,34 +235,32 @@ mod tests { } #[test] - fn test_lookup() { + fn test_get_global_type_vars() { } + #[test] + fn test_lookup() { - // pub fn get_global_type_vars(&self) -> Result> { - // module_get_global_type_vars(self.clone()) - // } + } - // pub fn contains_global_var>(&self, name: S) -> Result { - // module_contains_global_var(self.clone(), name.into()) - // } + #[test] + fn test_contains_global_var() { + } - // pub fn contains_global_type_var>(&self, name: S) -> Result { - // module_contains_global_type_var(self.clone(), name.into()) - // } + #[test] + fn test_contains_global_type_var() { + } #[test] fn test_lookup_def() { } - // pub fn lookup_def(&self, global: GlobalTypeVar) -> Result { - // module_lookup_def(self.clone(), global) - // } - // pub fn lookup_def_str(&self, global: GlobalTypeVar) -> Result { - // module_lookup_def_str(self.clone(), global) - // } + #[test] + fn lookup_def() { + + } // pub fn lookup_tag(&self, tag: i32) -> Result { // module_lookup_tag(self.clone(), tag) @@ -272,12 +270,10 @@ mod tests { // module_from_expr(expr, funcs, types) // } - // pub fn import>(&mut self, path: S) -> Result<()> { // module_import(self.clone(), path.into()) // } - // pub fn import_from_std>(&mut self, path: S) -> Result<()> { // module_import_from_std(self.clone(), path.into()) // } diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index f0fca3d5bb67..a89a1d6e0f9a 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -519,7 +519,7 @@ mod tests { #[test] fn test_global() -> Result<()> { - let gv = GlobalVar::new("main".to_string(), ObjectRef::null()); + let gv = GlobalVar::new("main".to_string(), Span::null()); let text = as_text(gv.clone()); assert!(text.contains("@main")); Ok(()) @@ -527,7 +527,7 @@ mod tests { #[test] fn test_var() -> Result<()> { - let var = Var::new("local".to_string(), Type::null(), ObjectRef::null()); + let var = Var::new("local".to_string(), Type::null(), Span::null()); let text = as_text(var.clone()); assert!(text.contains("%local")); Ok(()) diff --git a/rust/tvm/src/ir/ty.rs b/rust/tvm/src/ir/ty.rs index d12f094a63ea..57f2a790b7d5 100644 --- a/rust/tvm/src/ir/ty.rs +++ b/rust/tvm/src/ir/ty.rs @@ -23,6 +23,7 @@ use tvm_macros::Object; use tvm_rt::{array::Array, DataType}; use super::PrimExpr; +use super::relay::Constructor; #[repr(C)] #[derive(Object)] @@ -240,3 +241,32 @@ impl TensorType { // using TypeRelationFn = tvm::TypeRelationFn; // using TypeReporter = tvm::TypeReporter; // using TypeReporterNode = tvm::TypeReporterNode; + +/* TypeData container node. +\brief Stores all data for an Algebraic Data Type (ADT). + +In particular, it stores the handle (global type var) for an ADT +and the constructors used to build it and is kept in the module. Note +that type parameters are also indicated in the type data: this means that +for any instance of an ADT, the type parameters must be indicated. That is, +an ADT definition is treated as a type-level function, so an ADT handle +must be wrapped in a TypeCall node that instantiates the type-level arguments. +The kind checker enforces this. */ +#[repr(C)] +#[derive(Object)] +#[ref_name = "TypeData"] +#[type_key = "relay.TypeData"] +pub struct TypeDataNode { + // /*! + // * \brief The header is simply the name of the ADT. + // * We adopt nominal typing for ADT definitions; + // * that is, differently-named ADT definitions with same constructors + // * have different types. + // */ + pub base: Object, + pub type_name: GlobalTypeVar, + /// The type variables (to allow for polymorphism). + pub type_vars: Array, + /// The constructors. + pub constructors: Array, +} From 1b72ca0979728b9bcda37f8ef97e9fea80f84b53 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 4 Nov 2020 14:48:24 -0800 Subject: [PATCH 07/12] Make entire object system printable --- rust/tvm-macros/src/object.rs | 8 +++- rust/tvm-rt/src/array.rs | 7 +++ rust/tvm-rt/src/ndarray.rs | 2 +- rust/tvm-rt/src/object/mod.rs | 1 + rust/tvm-rt/src/object/object_ptr.rs | 22 +++++++++- rust/tvm-rt/src/string.rs | 2 +- rust/tvm/src/ir/arith.rs | 2 +- rust/tvm/src/ir/attrs.rs | 2 +- rust/tvm/src/ir/diagnostics/mod.rs | 7 +-- rust/tvm/src/ir/expr.rs | 6 +-- rust/tvm/src/ir/function.rs | 2 +- rust/tvm/src/ir/module.rs | 30 +++++++++---- rust/tvm/src/ir/op.rs | 2 +- rust/tvm/src/ir/relay/attrs/nn.rs | 14 +++--- rust/tvm/src/ir/relay/attrs/transform.rs | 2 +- rust/tvm/src/ir/relay/mod.rs | 45 ++++++++++---------- rust/tvm/src/ir/source_map.rs | 4 +- rust/tvm/src/ir/span.rs | 4 +- rust/tvm/src/ir/tir.rs | 2 +- rust/tvm/src/ir/ty.rs | 54 ++++++++++-------------- rust/tvm/src/transform.rs | 2 +- 21 files changed, 131 insertions(+), 89 deletions(-) diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs index 7e6a9343da89..50793a9988aa 100644 --- a/rust/tvm-macros/src/object.rs +++ b/rust/tvm-macros/src/object.rs @@ -77,6 +77,12 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { _ => panic!("derive only works for structs"), }; + let ref_derives = if derive { + quote! { #[derive(Debug, Clone)]} + } else { + quote! { #[derive(Clone)] } + }; + let mut expanded = quote! { unsafe impl #tvm_rt_crate::object::IsObject for #payload_id { const TYPE_KEY: &'static str = #type_key; @@ -89,7 +95,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { } } - #[derive(Clone)] + #ref_derives pub struct #ref_id(Option<#tvm_rt_crate::object::ObjectPtr<#payload_id>>); impl #tvm_rt_crate::object::IsObjectRef for #ref_id { diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index 98414f9c5b34..1850cebbb4ec 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -82,6 +82,13 @@ impl Array { } } +impl std::fmt::Debug for Array { + fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + let as_vec: Vec = self.clone().into_iter().collect(); + write!(formatter, "{:?}", as_vec) + } +} + pub struct IntoIter { array: Array, pos: isize, diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index ed280ccc2d80..07f783f0ef43 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -65,7 +65,7 @@ use crate::object::{Object, ObjectPtr}; /// See the [`module-level documentation`](../ndarray/index.html) for more details. #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "NDArray"] #[type_key = "runtime.NDArray"] pub struct NDArrayContainer { diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs index 7e6107d7cbd0..8c07ed9f0853 100644 --- a/rust/tvm-rt/src/object/mod.rs +++ b/rust/tvm-rt/src/object/mod.rs @@ -40,6 +40,7 @@ pub trait IsObjectRef: + TryFrom + for<'a> Into> + for<'a> TryFrom, Error = Error> + + std::fmt::Debug { type Object: IsObject; fn as_ptr(&self) -> Option<&ObjectPtr>; diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 0f61bd228988..97c474c29b1f 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -19,6 +19,7 @@ use std::convert::TryFrom; use std::ffi::CString; +use std::fmt; use std::ptr::NonNull; use std::sync::atomic::AtomicI32; @@ -147,6 +148,18 @@ impl Object { } } +// impl fmt::Debug for Object { +// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +// let index = +// format!("{} // key: {}", self.type_index, "the_key"); + +// f.debug_struct("Object") +// .field("type_index", &index) +// // TODO(@jroesch: do we expose other fields?) +// .finish() +// } +// } + /// An unsafe trait which should be implemented for an object /// subtype. /// @@ -154,7 +167,7 @@ impl Object { /// index, a method for accessing the base object given the /// subtype, and a typed delete method which is specialized /// to the subtype. -pub unsafe trait IsObject: AsRef { +pub unsafe trait IsObject: AsRef + std::fmt::Debug { const TYPE_KEY: &'static str; unsafe extern "C" fn typed_delete(object: *mut Self) { @@ -264,6 +277,13 @@ impl std::ops::Deref for ObjectPtr { } } +impl fmt::Debug for ObjectPtr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use std::ops::Deref; + write!(f, "{:?}", self.deref()) + } +} + impl<'a, T: IsObject> From> for RetValue { fn from(object_ptr: ObjectPtr) -> RetValue { let raw_object_ptr = ObjectPtr::leak(object_ptr) as *mut T as *mut std::ffi::c_void; diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs index a650cc994d98..e61afaf7399b 100644 --- a/rust/tvm-rt/src/string.rs +++ b/rust/tvm-rt/src/string.rs @@ -25,7 +25,7 @@ use super::Object; use tvm_macros::Object; #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "String"] #[type_key = "runtime.String"] #[no_derive] diff --git a/rust/tvm/src/ir/arith.rs b/rust/tvm/src/ir/arith.rs index 92a1de69ff78..672e6e6113a0 100644 --- a/rust/tvm/src/ir/arith.rs +++ b/rust/tvm/src/ir/arith.rs @@ -24,7 +24,7 @@ use tvm_macros::Object; macro_rules! define_node { ($name:ident, $ref:expr, $typekey:expr; $node:ident { $($id:ident : $t:ty),*}) => { #[repr(C)] - #[derive(Object)] + #[derive(Object, Debug)] #[ref_name = $ref] #[type_key = $typekey] pub struct $node { diff --git a/rust/tvm/src/ir/attrs.rs b/rust/tvm/src/ir/attrs.rs index 5bd027ab4b4c..739ed405c906 100644 --- a/rust/tvm/src/ir/attrs.rs +++ b/rust/tvm/src/ir/attrs.rs @@ -21,7 +21,7 @@ use crate::runtime::Object; use tvm_macros::Object; #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Attrs"] #[type_key = "Attrs"] pub struct BaseAttrsNode { diff --git a/rust/tvm/src/ir/diagnostics/mod.rs b/rust/tvm/src/ir/diagnostics/mod.rs index 051bb9eb16c4..8bcdf8f51e60 100644 --- a/rust/tvm/src/ir/diagnostics/mod.rs +++ b/rust/tvm/src/ir/diagnostics/mod.rs @@ -59,6 +59,7 @@ external! { /// The diagnostic level, controls the printing of the message. #[repr(C)] +#[derive(PartialEq, Eq, Debug)] pub enum DiagnosticLevel { Bug = 10, Error = 20, @@ -69,7 +70,7 @@ pub enum DiagnosticLevel { /// A compiler diagnostic. #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Diagnostic"] #[type_key = "Diagnostic"] pub struct DiagnosticNode { @@ -145,7 +146,7 @@ impl DiagnosticBuilder { /// of compiler diagnostics to std::out and std::err in /// a human readable form. #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "DiagnosticRenderer"] #[type_key = "DiagnosticRenderer"] /// A diagnostic renderer, which given a diagnostic context produces a "rendered" @@ -166,7 +167,7 @@ impl DiagnosticRenderer { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "DiagnosticContext"] #[type_key = "DiagnosticContext"] /// A diagnostic context for recording errors against a source file. diff --git a/rust/tvm/src/ir/expr.rs b/rust/tvm/src/ir/expr.rs index 4ce3d790dcae..018af0520777 100644 --- a/rust/tvm/src/ir/expr.rs +++ b/rust/tvm/src/ir/expr.rs @@ -28,7 +28,7 @@ use super::span::Span; #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "BaseExpr"] #[type_key = "Expr"] pub struct BaseExprNode { @@ -44,7 +44,7 @@ impl BaseExprNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "PrimExpr"] #[type_key = "PrimExpr"] pub struct PrimExprNode { @@ -62,7 +62,7 @@ impl PrimExprNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "GlobalVar"] #[type_key = "GlobalVar"] pub struct GlobalVarNode { diff --git a/rust/tvm/src/ir/function.rs b/rust/tvm/src/ir/function.rs index 3043bf9e7cff..14c00ea02bf6 100644 --- a/rust/tvm/src/ir/function.rs +++ b/rust/tvm/src/ir/function.rs @@ -28,7 +28,7 @@ use tvm_macros::Object; pub type DictAttrs = ObjectRef; #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "BaseFunc"] #[type_key = "BaseFunc"] pub struct BaseFuncNode { diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index 64679e04b3be..cd1db02cf801 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -17,6 +17,7 @@ * under the License. */ +use std::collections::HashMap; use std::iter::FromIterator; use std::path::Path; @@ -46,7 +47,7 @@ pub enum Error { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "IRModule"] #[type_key = "IRModule"] pub struct IRModuleNode { @@ -106,6 +107,12 @@ impl IRModule { module_new(Map::from_iter(funcs), Map::from_iter(types)) } + pub fn empty() -> Result { + let funcs = HashMap::::new(); + let types = HashMap::::new(); + IRModule::new(funcs, types) + } + pub fn parse(file_name: N, source: S) -> Result where N: Into, @@ -140,8 +147,9 @@ impl IRModule { module_add_def(self.clone(), type_name, type_data, update) } - pub fn get_global_var(&self, name: TVMString) -> Result { - module_get_global_var(self.clone(), name) + pub fn get_global_var(&self, name: S) -> Result + where S: Into { + module_get_global_var(self.clone(), name.into()) } pub fn get_global_vars(&self) -> Result> { @@ -206,9 +214,7 @@ mod tests { #[test] fn test_module_add() -> anyhow::Result<()> { - let funcs = HashMap::::new(); - let types = HashMap::::new(); - let mut module = IRModule::new(funcs, types)?; + let mut module = IRModule::empty()?; let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32()); let params = Array::from_vec(vec![x.clone()])?; let func = relay::Function::simple(params, x.upcast()).upcast(); @@ -225,8 +231,16 @@ mod tests { } #[test] - fn test_get_global_var() { - + fn test_get_global_var() -> Result<()> { + let mut module = IRModule::empty()?; + let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32()); + let params = vec![x.clone()]; + let func = relay::Function::simple(params, x.upcast()).upcast(); + let gv_foo = GlobalVar::new("foo".into(), Span::null()); + let module = module.add(gv_foo, func)?; + let gv = module.get_global_var("foo"); + assert_eq!(gv_foo, gv); + Ok(()) } #[test] diff --git a/rust/tvm/src/ir/op.rs b/rust/tvm/src/ir/op.rs index d81d6a69c1eb..d222ead0391b 100644 --- a/rust/tvm/src/ir/op.rs +++ b/rust/tvm/src/ir/op.rs @@ -27,7 +27,7 @@ type FuncType = ObjectRef; type AttrFieldInfo = ObjectRef; #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Op"] #[type_key = "Op"] pub struct OpNode { diff --git a/rust/tvm/src/ir/relay/attrs/nn.rs b/rust/tvm/src/ir/relay/attrs/nn.rs index cb96f0fbf588..7ecd92febc22 100644 --- a/rust/tvm/src/ir/relay/attrs/nn.rs +++ b/rust/tvm/src/ir/relay/attrs/nn.rs @@ -27,7 +27,7 @@ use tvm_macros::Object; type IndexExpr = PrimExpr; #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Conv2DAttrs"] #[type_key = "relay.attrs.Conv2DAttrs"] pub struct Conv2DAttrsNode { @@ -46,7 +46,7 @@ pub struct Conv2DAttrsNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "BiasAddAttrs"] #[type_key = "relay.attrs.BiasAddAttrs"] pub struct BiasAddAttrsNode { @@ -55,7 +55,7 @@ pub struct BiasAddAttrsNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "DenseAttrs"] #[type_key = "relay.attrs.DenseAttrs"] pub struct DenseAttrsNode { @@ -65,7 +65,7 @@ pub struct DenseAttrsNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "GlobalPool2DAttrs"] #[type_key = "relay.attrs.GlobalPool2DAttrs"] pub struct GlobalPool2DAttrsNode { @@ -74,7 +74,7 @@ pub struct GlobalPool2DAttrsNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "MaxPool2DAttrs"] #[type_key = "relay.attrs.MaxPool2DAttrs"] pub struct MaxPool2DAttrsNode { @@ -87,7 +87,7 @@ pub struct MaxPool2DAttrsNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "SoftmaxAttrs"] #[type_key = "relay.attrs.SoftmaxAttrs"] pub struct SoftmaxAttrsNode { @@ -96,7 +96,7 @@ pub struct SoftmaxAttrsNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "BatchNormAttrs"] #[type_key = "relay.attrs.BatchNormAttrs"] pub struct BatchNormAttrsNode { diff --git a/rust/tvm/src/ir/relay/attrs/transform.rs b/rust/tvm/src/ir/relay/attrs/transform.rs index 863f07617778..c459f96b2d2f 100644 --- a/rust/tvm/src/ir/relay/attrs/transform.rs +++ b/rust/tvm/src/ir/relay/attrs/transform.rs @@ -21,7 +21,7 @@ use crate::ir::attrs::BaseAttrsNode; use tvm_macros::Object; #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "ExpandDimsAttrs"] #[type_key = "relay.attrs.ExpandDimsAttrs"] pub struct ExpandDimsAttrsNode { diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index a89a1d6e0f9a..cf1a628607a8 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -34,7 +34,7 @@ pub use crate::runtime::DataType; pub mod attrs; #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Expr"] #[type_key = "RelayExpr"] pub struct ExprNode { @@ -57,7 +57,7 @@ impl ExprNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Id"] #[type_key = "relay.Id"] pub struct IdNode { @@ -76,7 +76,7 @@ impl Id { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Constant"] #[type_key = "relay.Constant"] pub struct ConstantNode { @@ -95,7 +95,7 @@ impl Constant { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Tuple"] #[type_key = "relay.Tuple"] pub struct TupleNode { @@ -114,7 +114,7 @@ impl Tuple { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Var"] #[type_key = "relay.Var"] pub struct VarNode { @@ -144,7 +144,7 @@ impl Var { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Call"] #[type_key = "relay.Call"] pub struct CallNode { @@ -175,7 +175,7 @@ impl Call { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Let"] #[type_key = "relay.Let"] pub struct LetNode { @@ -198,7 +198,7 @@ impl Let { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "If"] #[type_key = "relay.If"] pub struct IfNode { @@ -221,7 +221,7 @@ impl If { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "TupleGetItem"] #[type_key = "relay.TupleGetItem"] pub struct TupleGetItemNode { @@ -242,7 +242,7 @@ impl TupleGetItem { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "RefCreate"] #[type_key = "relay.RefCreate"] pub struct RefCreateNode { @@ -261,7 +261,7 @@ impl RefCreate { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "RefRead"] #[type_key = "relay.RefRead"] pub struct RefReadNode { @@ -280,7 +280,7 @@ impl RefRead { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "RefWrite"] #[type_key = "relay.RefWrite"] pub struct RefWriteNode { @@ -301,7 +301,7 @@ impl RefWrite { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Constructor"] #[type_key = "relay.Constructor"] pub struct ConstructorNode { @@ -326,7 +326,7 @@ impl Constructor { // TODO(@jroesch): define the type data #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Pattern"] #[type_key = "relay.Pattern"] pub struct PatternNode { @@ -344,7 +344,7 @@ impl PatternNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "PatternWildcard"] #[type_key = "relay.PatternWildcard"] pub struct PatternWildcardNode { @@ -361,7 +361,7 @@ impl PatternWildcard { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "PatternVar"] #[type_key = "relay.PatternVar"] pub struct PatternVarNode { @@ -380,7 +380,7 @@ impl PatternVar { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "PatternConstructor"] #[type_key = "relay.PatternConstructor"] pub struct PatternConstructorNode { @@ -405,7 +405,7 @@ impl PatternConstructor { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "PatternTuple"] #[type_key = "relay.PatternTuple"] pub struct PatternTupleNode { @@ -424,7 +424,7 @@ impl PatternTuple { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Clause"] #[type_key = "relay.Clause"] pub struct ClauseNode { @@ -445,7 +445,7 @@ impl Clause { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Match"] #[type_key = "relay.Match"] pub struct MatchNode { @@ -468,7 +468,7 @@ impl Match { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Function"] #[type_key = "relay.Function"] pub struct FunctionNode { @@ -496,7 +496,8 @@ impl Function { Function(Some(ObjectPtr::new(node))) } - pub fn simple(params: Array, body: Expr) -> Function { + pub fn simple(params: Vec, body: Expr) -> Function { + let params = Array::from_vec(params).unwrap(); Self::new(params, body, Type::null(), Array::from_vec(vec![]).unwrap()) } } diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs index 54e16dac62ac..7376f4b74022 100644 --- a/rust/tvm/src/ir/source_map.rs +++ b/rust/tvm/src/ir/source_map.rs @@ -29,7 +29,7 @@ use tvm_macros::Object; /// /// Could represent the source from an ML framework or a source of an IRModule. #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[type_key = "Source"] #[ref_name = "Source"] pub struct SourceNode { @@ -46,7 +46,7 @@ pub struct SourceNode { /// A mapping from a unique source name to source fragments. #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[type_key = "SourceMap"] #[ref_name = "SourceMap"] pub struct SourceMapNode { diff --git a/rust/tvm/src/ir/span.rs b/rust/tvm/src/ir/span.rs index eb6821af69dc..be74745b60ca 100644 --- a/rust/tvm/src/ir/span.rs +++ b/rust/tvm/src/ir/span.rs @@ -23,7 +23,7 @@ use tvm_macros::Object; /// A source file name, contained in a Span. #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[type_key = "SourceName"] #[ref_name = "SourceName"] pub struct SourceNameNode { @@ -33,7 +33,7 @@ pub struct SourceNameNode { /// Span information for diagnostic purposes. #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[type_key = "Span"] #[ref_name = "Span"] pub struct SpanNode { diff --git a/rust/tvm/src/ir/tir.rs b/rust/tvm/src/ir/tir.rs index f07e85486626..ccbe30c95820 100644 --- a/rust/tvm/src/ir/tir.rs +++ b/rust/tvm/src/ir/tir.rs @@ -26,7 +26,7 @@ use tvm_macros::Object; macro_rules! define_node { ($name:ident, $ref:expr, $typekey:expr; $node:ident { $($id:ident : $t:ty),*}) => { #[repr(C)] - #[derive(Object)] + #[derive(Object, Debug)] #[ref_name = $ref] #[type_key = $typekey] pub struct $node { diff --git a/rust/tvm/src/ir/ty.rs b/rust/tvm/src/ir/ty.rs index 57f2a790b7d5..488195ab4e6e 100644 --- a/rust/tvm/src/ir/ty.rs +++ b/rust/tvm/src/ir/ty.rs @@ -26,7 +26,7 @@ use super::PrimExpr; use super::relay::Constructor; #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "Type"] #[type_key = "Type"] pub struct TypeNode { @@ -52,7 +52,7 @@ impl TypeNode { * \sa PrimType */ #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "PrimType"] #[type_key = "PrimType"] pub struct PrimTypeNode { @@ -74,7 +74,7 @@ pub struct PrimTypeNode { */ #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "PointerType"] #[type_key = "PointerType"] pub struct PointerTypeNode { @@ -83,7 +83,9 @@ pub struct PointerTypeNode { pub element_type: Type, } + /// Possible kinds of type variables. +#[derive(PartialEq, Eq, Debug)] pub enum TypeKind { Type = 0, /// Template variable in shape expression. @@ -93,26 +95,16 @@ pub enum TypeKind { TypeData = 6, } -/* - * \brief Type parameter in functions. - * - * A type variable can be viewed as template parameter in c++ template function. - * - * For example, in the following pesudo code, - * the TypeVar of f is TypeVar("n", kind=kShapeVar). - * This function can take in a Tensor with shape=(3, 3) and - * returns a Tensor with shape=(9,) - * - * \code - * - * template - * f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)] - * - * \endcode - * \sa TypeVar, TypeKind - */ +/// Type parameter in functions. +/// +/// A type variable can be viewed as template parameter in c++ template function. +/// +/// For example, in the following pesudo code, +/// the TypeVar of f is TypeVar("n", kind=kShapeVar). +/// This function can take in a Tensor with shape=(3, 3) and +/// returns a Tensor with shape=(9,) #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "TypeVar"] #[type_key = "TypeVar"] pub struct TypeVarNode { @@ -123,7 +115,7 @@ pub struct TypeVarNode { /// A global type variable that is used for defining new types or type aliases. #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "GlobalTypeVar"] #[type_key = "GlobalTypeVar"] pub struct GlobalTypeVarNode { @@ -133,7 +125,7 @@ pub struct GlobalTypeVarNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "TupleType"] #[type_key = "TupleType"] pub struct TupleTypeNode { @@ -148,7 +140,7 @@ impl TupleType { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "TypeConstraint"] #[type_key = "TypeConstraint"] pub struct TypeConstraintNode { @@ -157,7 +149,7 @@ pub struct TypeConstraintNode { /// The representation of a polymorphic function type. #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "FuncType"] #[type_key = "FuncType"] pub struct FuncTypeNode { @@ -182,7 +174,7 @@ pub struct FuncTypeNode { * TypeVar represents the input to the graph. */ #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "IncompleteType"] #[type_key = "IncompleteType"] pub struct IncompleteTypeNode { @@ -196,7 +188,7 @@ pub struct IncompleteTypeNode { * \sa RelayRefType. */ #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "RefType"] #[type_key = "relay.RefType"] pub struct RelayRefTypeNode { @@ -205,7 +197,7 @@ pub struct RelayRefTypeNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "BaseTensorType"] #[type_key = "relay.BaseTensorType"] pub struct BaseTensorTypeNode { @@ -213,7 +205,7 @@ pub struct BaseTensorTypeNode { } #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "TensorType"] #[type_key = "relay.TensorType"] pub struct TensorTypeNode { @@ -253,7 +245,7 @@ an ADT definition is treated as a type-level function, so an ADT handle must be wrapped in a TypeCall node that instantiates the type-level arguments. The kind checker enforces this. */ #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "TypeData"] #[type_key = "relay.TypeData"] pub struct TypeDataNode { diff --git a/rust/tvm/src/transform.rs b/rust/tvm/src/transform.rs index c5a65c417c93..b49633777b65 100644 --- a/rust/tvm/src/transform.rs +++ b/rust/tvm/src/transform.rs @@ -33,7 +33,7 @@ pub type IRModule = ObjectRef; pub type PassContext = ObjectRef; #[repr(C)] -#[derive(Object)] +#[derive(Object, Debug)] #[ref_name = "PassInfo"] #[type_key = "transform.PassInfo"] pub struct PassInfoNode { From fdaa4af44dfc5a69a78e361dd82f4bc397307862 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 4 Nov 2020 17:56:17 -0800 Subject: [PATCH 08/12] Write some more tests for IRModule --- rust/tvm-rt/src/array.rs | 9 +++- rust/tvm/src/ir/module.rs | 102 ++++++++++++++++++++--------------- rust/tvm/src/ir/relay/mod.rs | 7 +-- rust/tvm/src/ir/ty.rs | 47 ++++++++++++---- 4 files changed, 107 insertions(+), 58 deletions(-) diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index 1850cebbb4ec..a6afb54ecd30 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -18,7 +18,7 @@ */ use std::convert::{TryFrom, TryInto}; -use std::iter::{IntoIterator, Iterator}; +use std::iter::{IntoIterator, Iterator, FromIterator}; use std::marker::PhantomData; use crate::errors::Error; @@ -125,6 +125,13 @@ impl IntoIterator for Array { } } +impl FromIterator for Array { + fn from_iter>(iter: I) -> Self { + Array::from_vec(iter.into_iter().collect()).unwrap() + } +} + + impl From> for ArgValue<'static> { fn from(array: Array) -> ArgValue<'static> { array.object.into() diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index cd1db02cf801..6136b449a968 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -28,15 +28,12 @@ use crate::runtime::array::Array; use crate::runtime::function::Result; use crate::runtime::map::Map; use crate::runtime::string::String as TVMString; -use crate::runtime::{external, Object, ObjectRef}; +use crate::runtime::{external, Object, IsObjectRef}; use super::expr::GlobalVar; -use super::function::BaseFunc; +use super::function::{BaseFunc}; use super::source_map::SourceMap; -use super::{ty::GlobalTypeVar, relay}; - -// TODO(@jroesch): define type -type TypeData = ObjectRef; +use super::{ty::GlobalTypeVar, ty::TypeData, relay}; #[derive(Error, Debug)] pub enum Error { @@ -88,7 +85,7 @@ external! { #[name("ir.Module_LookupDef")] fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeData; #[name("ir.Module_LookupDef_str")] - fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> TypeData; + fn module_lookup_def_str(module: IRModule, global: TVMString) -> TypeData; #[name("ir.Module_LookupTag")] fn module_lookup_tag(module: IRModule, tag: i32) -> relay::Constructor; #[name("ir.Module_FromExpr")] @@ -131,11 +128,13 @@ impl IRModule { Ok(module) } - pub fn add( + pub fn add( &mut self, var: GlobalVar, - func: BaseFunc) -> Result { - module_add(self.clone(), var, func, true) + func: F) -> Result + // todo(@jroesch): can we do better here? why doesn't BaseFunc::Object work? + where F: IsObjectRef, F::Object: AsRef<::Object> { + module_add(self.clone(), var, func.upcast(), true) } pub fn add_def( @@ -183,8 +182,9 @@ impl IRModule { module_lookup_def(self.clone(), global) } - pub fn lookup_def_str(&self, global: GlobalTypeVar) -> Result { - module_lookup_def_str(self.clone(), global) + pub fn lookup_def_str(&self, global: S) -> Result + where S: Into { + module_lookup_def_str(self.clone(), global.into()) } pub fn lookup_tag(&self, tag: i32) -> Result { @@ -206,18 +206,18 @@ impl IRModule { #[cfg(test)] mod tests { - use std::collections::HashMap; use super::relay::*; use super::*; - use super::super::span::Span; use tvm_rt::IsObjectRef; + use crate::ir::ty::{GlobalTypeVar, TypeData, TypeKind}; + use crate::ir::span::Span; #[test] fn test_module_add() -> anyhow::Result<()> { let mut module = IRModule::empty()?; let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32()); - let params = Array::from_vec(vec![x.clone()])?; - let func = relay::Function::simple(params, x.upcast()).upcast(); + let params = vec![x.clone()]; + let func = relay::Function::simple(params, x); let module = module.add(GlobalVar::new("foo".into(), Span::null()), func)?; let lfunc = module.lookup_str("foo")?; let lfunc = lfunc.downcast::()?; @@ -226,8 +226,14 @@ mod tests { } #[test] - fn test_module_add_def() { - todo!("this is blocked on having ability to define ADTs") + fn test_module_add_def() -> Result<()> { + let mut module = IRModule::empty()?; + let name = GlobalTypeVar::new("my_type", TypeKind::Type, Span::null()); + let type_data = TypeData::new(name.clone(), vec![], vec![]); + module.add_def(name.clone(), type_data, true)?; + let by_gtv = module.lookup_def(name)?; + let by_gv = module.lookup_def_str("my_type")?; + Ok(()) } #[test] @@ -235,17 +241,25 @@ mod tests { let mut module = IRModule::empty()?; let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32()); let params = vec![x.clone()]; - let func = relay::Function::simple(params, x.upcast()).upcast(); + let func = relay::Function::simple(params, x); let gv_foo = GlobalVar::new("foo".into(), Span::null()); - let module = module.add(gv_foo, func)?; - let gv = module.get_global_var("foo"); + let module = module.add(gv_foo.clone(), func)?; + let gv = module.get_global_var("foo")?; assert_eq!(gv_foo, gv); Ok(()) } #[test] - fn test_get_global_vars() { - + fn test_get_global_vars() -> Result<()> { + let mut module = IRModule::empty()?; + let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32()); + let params = vec![x.clone()]; + let func = relay::Function::simple(params, x); + let gv_foo = GlobalVar::new("foo".into(), Span::null()); + let module = module.add(gv_foo.clone(), func)?; + let gv = module.get_global_var("foo")?; + assert_eq!(gv_foo, gv); + Ok(()) } #[test] @@ -253,11 +267,6 @@ mod tests { } - #[test] - fn test_lookup() { - - } - #[test] fn test_contains_global_var() { } @@ -266,29 +275,34 @@ mod tests { fn test_contains_global_type_var() { } - #[test] - fn test_lookup_def() { - - } - - #[test] - fn lookup_def() { - - } - + // TODO(@jroesch): not really sure about this API at all. // pub fn lookup_tag(&self, tag: i32) -> Result { // module_lookup_tag(self.clone(), tag) // } + // TODO(@jroesch): do we need to test this? // pub fn from_expr(expr: relay::Expr, funcs: Map, types: Map) -> Result { // module_from_expr(expr, funcs, types) // } - // pub fn import>(&mut self, path: S) -> Result<()> { - // module_import(self.clone(), path.into()) - // } + #[test] + fn test_import() -> Result<()> { + let mut std_path: String = env!("CARGO_MANIFEST_DIR").into(); + std_path += "/../../python/tvm/relay/std/prelude.rly"; - // pub fn import_from_std>(&mut self, path: S) -> Result<()> { - // module_import_from_std(self.clone(), path.into()) - // } + let mut mod1 = IRModule::empty()?; + mod1.import(std_path.clone())?; + mod1.lookup_str("map")?; + + // TODO(@jroesch): this requires another patch of mine to enable. + + // if cfg!(feature = "python") { + // crate::python::load().unwrap(); + // let mut mod2 = IRModule::empty()?; + // mod2.import_from_std("prelude.rly")?; + // mod2.lookup_str("map")?; + // } + + Ok(()) + } } diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index cf1a628607a8..8dd96173e5a8 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -496,9 +496,10 @@ impl Function { Function(Some(ObjectPtr::new(node))) } - pub fn simple(params: Vec, body: Expr) -> Function { + pub fn simple(params: Vec, body: E) -> Function + where E: IsObjectRef, E::Object: AsRef<::Object> { let params = Array::from_vec(params).unwrap(); - Self::new(params, body, Type::null(), Array::from_vec(vec![]).unwrap()) + Self::new(params, body.upcast(), Type::null(), Array::from_vec(vec![]).unwrap()) } } @@ -547,7 +548,7 @@ def @main() -> float32 { ) .unwrap(); let main = module - .lookup(module.get_global_var("main".to_string().into()).unwrap()) + .lookup(module.get_global_var("main").unwrap()) .unwrap(); let func = main.downcast::().unwrap(); let constant = func diff --git a/rust/tvm/src/ir/ty.rs b/rust/tvm/src/ir/ty.rs index 488195ab4e6e..d59940ff985e 100644 --- a/rust/tvm/src/ir/ty.rs +++ b/rust/tvm/src/ir/ty.rs @@ -17,13 +17,14 @@ * under the License. */ -use super::span::Span; -use crate::runtime::{IsObject, Object, ObjectPtr}; + use tvm_macros::Object; use tvm_rt::{array::Array, DataType}; -use super::PrimExpr; -use super::relay::Constructor; +use crate::ir::span::Span; +use crate::ir::relay::Constructor; +use crate::ir::PrimExpr; +use crate::runtime::{IsObject, Object, ObjectPtr}; #[repr(C)] #[derive(Object, Debug)] @@ -124,6 +125,18 @@ pub struct GlobalTypeVarNode { pub kind: TypeKind, } +impl GlobalTypeVar { + pub fn new(name_hint: S, kind: TypeKind, span: Span) -> GlobalTypeVar + where S: Into { + let node = GlobalTypeVarNode { + base: TypeNode::base::(span), + name_hint: name_hint.into(), + kind: kind, + }; + ObjectPtr::new(node).into() + } +} + #[repr(C)] #[derive(Object, Debug)] #[ref_name = "TupleType"] @@ -249,12 +262,10 @@ The kind checker enforces this. */ #[ref_name = "TypeData"] #[type_key = "relay.TypeData"] pub struct TypeDataNode { - // /*! - // * \brief The header is simply the name of the ADT. - // * We adopt nominal typing for ADT definitions; - // * that is, differently-named ADT definitions with same constructors - // * have different types. - // */ + /// The header is simply the name of the ADT. + /// We adopt nominal typing for ADT definitions; + /// that is, differently-named ADT definitions with same constructors + /// have different types. pub base: Object, pub type_name: GlobalTypeVar, /// The type variables (to allow for polymorphism). @@ -262,3 +273,19 @@ pub struct TypeDataNode { /// The constructors. pub constructors: Array, } + +impl TypeData { + pub fn new(type_name: GlobalTypeVar, type_vars: TypeVars, constructors: Ctors) -> TypeData + where TypeVars: IntoIterator, + Ctors: IntoIterator, + { + use std::iter::FromIterator; + let type_data = TypeDataNode { + base: Object::base::(), + type_name, + type_vars: Array::from_iter(type_vars), + constructors: Array::from_iter(constructors), + }; + TypeData(Some(ObjectPtr::new(type_data))) + } +} From 726a01b3115407f4d5923e2c3f55f7477ecde9cf Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 4 Nov 2020 19:05:20 -0800 Subject: [PATCH 09/12] All tests pass --- rust/tvm-rt/src/map.rs | 2 - rust/tvm/src/ir/module.rs | 100 +++++++++++++++++++++++++++++++------- rust/tvm/src/ir/ty.rs | 14 +++--- src/ir/module.cc | 8 ++- 4 files changed, 96 insertions(+), 28 deletions(-) diff --git a/rust/tvm-rt/src/map.rs b/rust/tvm-rt/src/map.rs index 721fb1ec4588..b8bfb4e5e644 100644 --- a/rust/tvm-rt/src/map.rs +++ b/rust/tvm-rt/src/map.rs @@ -48,8 +48,6 @@ where // TODO(@jroesch): convert to use generics instead of casting inside // the implementation. external! { - #[name("node.ArrayGetItem")] - fn array_get_item(array: ObjectRef, index: isize) -> ObjectRef; #[name("node.MapSize")] fn map_size(map: ObjectRef) -> i64; #[name("node.MapGetItem")] diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index 6136b449a968..a49e587b7be9 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -191,10 +191,20 @@ impl IRModule { module_lookup_tag(self.clone(), tag) } - pub fn from_expr(expr: relay::Expr, funcs: Map, types: Map) -> Result { - module_from_expr(expr, funcs, types) + pub fn from_expr(expr: E) -> Result + where E: IsObjectRef, E::Object: AsRef<::Object> { + Self::from_expr_with_items(expr, HashMap::new(), HashMap::new()) } + pub fn from_expr_with_items(expr: E, funcs: F, types: T) -> Result + where F: IntoIterator, + T: IntoIterator, + E: IsObjectRef, + E::Object: AsRef<::Object> { + module_from_expr(expr.upcast(), Map::from_iter(funcs), Map::from_iter(types)) + } + + pub fn import>(&mut self, path: S) -> Result<()> { module_import(self.clone(), path.into()) } @@ -212,6 +222,33 @@ mod tests { use crate::ir::ty::{GlobalTypeVar, TypeData, TypeKind}; use crate::ir::span::Span; + fn add_dummy_functions(names: Vec<&str>) -> Result { + let mut module = IRModule::empty()?; + let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32()); + let params = vec![x.clone()]; + let func = relay::Function::simple(params, x); + + for name in names { + let gv = GlobalVar::new(name.into(), Span::null()); + module = module.add(gv, func.clone())?; + } + + Ok(module) + } + + fn add_dummy_types(names: Vec<&str>) -> Result { + let mut module = IRModule::empty()?; + + for name in names { + let name: String = name.into(); + let name = GlobalTypeVar::new(name, TypeKind::Type, Span::null()); + let type_data = TypeData::new(name.clone(), vec![], vec![], Span::null()); + module.add_def(name, type_data, true)?; + } + + Ok(module) + } + #[test] fn test_module_add() -> anyhow::Result<()> { let mut module = IRModule::empty()?; @@ -229,7 +266,7 @@ mod tests { fn test_module_add_def() -> Result<()> { let mut module = IRModule::empty()?; let name = GlobalTypeVar::new("my_type", TypeKind::Type, Span::null()); - let type_data = TypeData::new(name.clone(), vec![], vec![]); + let type_data = TypeData::new(name.clone(), vec![], vec![], Span::null()); module.add_def(name.clone(), type_data, true)?; let by_gtv = module.lookup_def(name)?; let by_gv = module.lookup_def_str("my_type")?; @@ -251,28 +288,48 @@ mod tests { #[test] fn test_get_global_vars() -> Result<()> { - let mut module = IRModule::empty()?; - let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32()); - let params = vec![x.clone()]; - let func = relay::Function::simple(params, x); - let gv_foo = GlobalVar::new("foo".into(), Span::null()); - let module = module.add(gv_foo.clone(), func)?; - let gv = module.get_global_var("foo")?; - assert_eq!(gv_foo, gv); + let names = vec!["foo", "bar", "baz"]; + let module = add_dummy_functions(names.clone())?; + let gvars: Vec = + module.get_global_vars()?.into_iter().map(|gv| { + gv.name_hint.as_str().unwrap().to_string() + }).collect(); + + for name in names { + assert!(gvars.contains(&name.to_string())); + } + Ok(()) } #[test] - fn test_get_global_type_vars() { + fn test_get_global_type_vars() -> Result<()> { + let names = vec!["foo", "bar", "baz"]; + let module = add_dummy_types(names.clone())?; + let gvars: Vec = + module.get_global_type_vars()?.into_iter().map(|gv| { + gv.name_hint.as_str().unwrap().to_string() + }).collect(); + + for name in names { + assert!(gvars.contains(&name.to_string())); + } + Ok(()) } #[test] - fn test_contains_global_var() { + fn test_contains_global_var() -> Result<()> { + let module = add_dummy_functions(vec!["foo"])?; + assert!(module.contains_global_var("foo")?); + Ok(()) } #[test] - fn test_contains_global_type_var() { + fn test_contains_global_type_var() -> Result<()> { + let module = add_dummy_types(vec!["foo"])?; + assert!(module.contains_global_type_var("foo")?); + Ok(()) } // TODO(@jroesch): not really sure about this API at all. @@ -280,10 +337,17 @@ mod tests { // module_lookup_tag(self.clone(), tag) // } - // TODO(@jroesch): do we need to test this? - // pub fn from_expr(expr: relay::Expr, funcs: Map, types: Map) -> Result { - // module_from_expr(expr, funcs, types) - // } + #[test] + fn test_from_expr() -> Result<()> { + let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32()); + let params = vec![x.clone()]; + let func = relay::Function::simple(params, x); + let module = IRModule::from_expr(func.clone())?; + let main_fn = module.lookup_str("main")?; + let main_fn = main_fn.downcast::()?; + assert_eq!(main_fn, func); + Ok(()) + } #[test] fn test_import() -> Result<()> { diff --git a/rust/tvm/src/ir/ty.rs b/rust/tvm/src/ir/ty.rs index d59940ff985e..f5b9cb12fd98 100644 --- a/rust/tvm/src/ir/ty.rs +++ b/rust/tvm/src/ir/ty.rs @@ -24,7 +24,7 @@ use tvm_rt::{array::Array, DataType}; use crate::ir::span::Span; use crate::ir::relay::Constructor; use crate::ir::PrimExpr; -use crate::runtime::{IsObject, Object, ObjectPtr}; +use crate::runtime::{IsObject, Object, ObjectPtr, string::String as TString}; #[repr(C)] #[derive(Object, Debug)] @@ -110,7 +110,7 @@ pub enum TypeKind { #[type_key = "TypeVar"] pub struct TypeVarNode { pub base: TypeNode, - pub name_hint: String, + pub name_hint: TString, pub kind: TypeKind, } @@ -121,13 +121,13 @@ pub struct TypeVarNode { #[type_key = "GlobalTypeVar"] pub struct GlobalTypeVarNode { pub base: TypeNode, - pub name_hint: String, + pub name_hint: TString, pub kind: TypeKind, } impl GlobalTypeVar { pub fn new(name_hint: S, kind: TypeKind, span: Span) -> GlobalTypeVar - where S: Into { + where S: Into { let node = GlobalTypeVarNode { base: TypeNode::base::(span), name_hint: name_hint.into(), @@ -266,7 +266,7 @@ pub struct TypeDataNode { /// We adopt nominal typing for ADT definitions; /// that is, differently-named ADT definitions with same constructors /// have different types. - pub base: Object, + pub base: TypeNode, pub type_name: GlobalTypeVar, /// The type variables (to allow for polymorphism). pub type_vars: Array, @@ -275,13 +275,13 @@ pub struct TypeDataNode { } impl TypeData { - pub fn new(type_name: GlobalTypeVar, type_vars: TypeVars, constructors: Ctors) -> TypeData + pub fn new(type_name: GlobalTypeVar, type_vars: TypeVars, constructors: Ctors, span: Span) -> TypeData where TypeVars: IntoIterator, Ctors: IntoIterator, { use std::iter::FromIterator; let type_data = TypeDataNode { - base: Object::base::(), + base: TypeNode::base::(span,), type_name, type_vars: Array::from_iter(type_vars), constructors: Array::from_iter(constructors), diff --git a/src/ir/module.cc b/src/ir/module.cc index b011f2d2f664..25c21e81fcf9 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -425,7 +425,9 @@ TVM_REGISTER_GLOBAL("ir.Module_Add").set_body([](TVMArgs args, TVMRetValue* ret) *ret = mod; }); -TVM_REGISTER_GLOBAL("ir.Module_AddDef").set_body_method(&IRModuleNode::AddTypeDef); +TVM_REGISTER_GLOBAL("ir.Module_AddDef").set_body_typed([](IRModule module, GlobalTypeVar var, TypeData type_def, bool update) { + module->AddTypeDef(var, type_def, update); +}); TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar") .set_body_method(&IRModuleNode::GetGlobalVar); @@ -439,6 +441,10 @@ TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars") TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar") .set_body_method(&IRModuleNode::ContainGlobalVar); +TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalTypeVar") + .set_body_method(&IRModuleNode::ContainGlobalTypeVar); + + TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar") .set_body_method(&IRModuleNode::GetGlobalTypeVar); From 79275b304c6a80d1eb34a57e1830119bf7f879ef Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 4 Nov 2020 19:05:54 -0800 Subject: [PATCH 10/12] Format --- rust/tvm-macros/src/external.rs | 17 ++++--- rust/tvm-macros/src/object.rs | 5 +- rust/tvm-rt/src/array.rs | 5 +- rust/tvm-rt/src/object/object_ptr.rs | 4 +- rust/tvm/src/ir/expr.rs | 1 - rust/tvm/src/ir/module.rs | 75 ++++++++++++++++------------ rust/tvm/src/ir/relay/mod.rs | 18 +++++-- rust/tvm/src/ir/ty.rs | 24 +++++---- src/ir/module.cc | 8 +-- 9 files changed, 96 insertions(+), 61 deletions(-) diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs index 51a389b2f2f2..146f9d4d6bc6 100644 --- a/rust/tvm-macros/src/external.rs +++ b/rust/tvm-macros/src/external.rs @@ -21,7 +21,10 @@ use proc_macro_error::abort; use quote::quote; use syn::parse::{Parse, ParseStream, Result}; -use syn::{FnArg, Signature, Attribute, token::Semi, Visibility, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, Type}; +use syn::{ + token::Semi, Attribute, FnArg, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, + Signature, Type, Visibility, +}; struct ExternalItem { attrs: Vec, @@ -133,7 +136,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { syn::GenericParam::Type(param) => param.clone(), _ => abort! { ty_param, "Only supports type parameters." - } + }, }) .collect(); @@ -148,15 +151,15 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let ty: Type = *pat_type.ty.clone(); (ident, ty) } - _ => abort! { pat_type, + _ => abort! { pat_type, "Only supports type parameters." - } + }, }, pat => abort! { - pat, "invalid pattern type for function"; + pat, "invalid pattern type for function"; - note = "{:?} is not allowed here", pat; - } + note = "{:?} is not allowed here", pat; + }, }) .unzip(); diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs index 50793a9988aa..c84d0aab612f 100644 --- a/rust/tvm-macros/src/object.rs +++ b/rust/tvm-macros/src/object.rs @@ -36,7 +36,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { .map(attr_to_str) .expect("Failed to get type_key"); - let derive = get_attr(&derive_input, "no_derive").map(|_| false).unwrap_or(true); + let derive = get_attr(&derive_input, "no_derive") + .map(|_| false) + .unwrap_or(true); let ref_id = get_attr(&derive_input, "ref_name") .map(|a| Ident::new(attr_to_str(a).value().as_str(), Span::call_site())) @@ -210,7 +212,6 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { impl std::cmp::Eq for #ref_id {} }; - expanded.extend(derives); } diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index a6afb54ecd30..1b0ce8399d1f 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -18,7 +18,7 @@ */ use std::convert::{TryFrom, TryInto}; -use std::iter::{IntoIterator, Iterator, FromIterator}; +use std::iter::{FromIterator, IntoIterator, Iterator}; use std::marker::PhantomData; use crate::errors::Error; @@ -126,12 +126,11 @@ impl IntoIterator for Array { } impl FromIterator for Array { - fn from_iter>(iter: I) -> Self { + fn from_iter>(iter: I) -> Self { Array::from_vec(iter.into_iter().collect()).unwrap() } } - impl From> for ArgValue<'static> { fn from(array: Array) -> ArgValue<'static> { array.object.into() diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 97c474c29b1f..8df6041956b8 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -364,7 +364,9 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { impl std::hash::Hash for ObjectPtr { fn hash(&self, state: &mut H) { - state.write_i64(super::structural_hash(ObjectRef(Some(self.clone().upcast())), false).unwrap()) + state.write_i64( + super::structural_hash(ObjectRef(Some(self.clone().upcast())), false).unwrap(), + ) } } diff --git a/rust/tvm/src/ir/expr.rs b/rust/tvm/src/ir/expr.rs index 018af0520777..653169def3a4 100644 --- a/rust/tvm/src/ir/expr.rs +++ b/rust/tvm/src/ir/expr.rs @@ -26,7 +26,6 @@ use crate::DataType; use super::relay; use super::span::Span; - #[repr(C)] #[derive(Object, Debug)] #[ref_name = "BaseExpr"] diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index a49e587b7be9..a09f70dc25b9 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -28,12 +28,12 @@ use crate::runtime::array::Array; use crate::runtime::function::Result; use crate::runtime::map::Map; use crate::runtime::string::String as TVMString; -use crate::runtime::{external, Object, IsObjectRef}; +use crate::runtime::{external, IsObjectRef, Object}; use super::expr::GlobalVar; -use super::function::{BaseFunc}; +use super::function::BaseFunc; use super::source_map::SourceMap; -use super::{ty::GlobalTypeVar, ty::TypeData, relay}; +use super::{relay, ty::GlobalTypeVar, ty::TypeData}; #[derive(Error, Debug)] pub enum Error { @@ -100,13 +100,16 @@ external! { impl IRModule { pub fn new(funcs: F, types: T) -> Result - where F: IntoIterator, T: IntoIterator { + where + F: IntoIterator, + T: IntoIterator, + { module_new(Map::from_iter(funcs), Map::from_iter(types)) } pub fn empty() -> Result { let funcs = HashMap::::new(); - let types = HashMap::::new(); + let types = HashMap::::new(); IRModule::new(funcs, types) } @@ -128,14 +131,14 @@ impl IRModule { Ok(module) } - pub fn add( - &mut self, - var: GlobalVar, - func: F) -> Result - // todo(@jroesch): can we do better here? why doesn't BaseFunc::Object work? - where F: IsObjectRef, F::Object: AsRef<::Object> { - module_add(self.clone(), var, func.upcast(), true) - } + pub fn add(&mut self, var: GlobalVar, func: F) -> Result + // todo(@jroesch): can we do better here? why doesn't BaseFunc::Object work? + where + F: IsObjectRef, + F::Object: AsRef<::Object>, + { + module_add(self.clone(), var, func.upcast(), true) + } pub fn add_def( &mut self, @@ -147,7 +150,9 @@ impl IRModule { } pub fn get_global_var(&self, name: S) -> Result - where S: Into { + where + S: Into, + { module_get_global_var(self.clone(), name.into()) } @@ -183,7 +188,9 @@ impl IRModule { } pub fn lookup_def_str(&self, global: S) -> Result - where S: Into { + where + S: Into, + { module_lookup_def_str(self.clone(), global.into()) } @@ -192,19 +199,23 @@ impl IRModule { } pub fn from_expr(expr: E) -> Result - where E: IsObjectRef, E::Object: AsRef<::Object> { + where + E: IsObjectRef, + E::Object: AsRef<::Object>, + { Self::from_expr_with_items(expr, HashMap::new(), HashMap::new()) } pub fn from_expr_with_items(expr: E, funcs: F, types: T) -> Result - where F: IntoIterator, - T: IntoIterator, - E: IsObjectRef, - E::Object: AsRef<::Object> { + where + F: IntoIterator, + T: IntoIterator, + E: IsObjectRef, + E::Object: AsRef<::Object>, + { module_from_expr(expr.upcast(), Map::from_iter(funcs), Map::from_iter(types)) } - pub fn import>(&mut self, path: S) -> Result<()> { module_import(self.clone(), path.into()) } @@ -218,9 +229,9 @@ impl IRModule { mod tests { use super::relay::*; use super::*; - use tvm_rt::IsObjectRef; - use crate::ir::ty::{GlobalTypeVar, TypeData, TypeKind}; use crate::ir::span::Span; + use crate::ir::ty::{GlobalTypeVar, TypeData, TypeKind}; + use tvm_rt::IsObjectRef; fn add_dummy_functions(names: Vec<&str>) -> Result { let mut module = IRModule::empty()?; @@ -290,10 +301,11 @@ mod tests { fn test_get_global_vars() -> Result<()> { let names = vec!["foo", "bar", "baz"]; let module = add_dummy_functions(names.clone())?; - let gvars: Vec = - module.get_global_vars()?.into_iter().map(|gv| { - gv.name_hint.as_str().unwrap().to_string() - }).collect(); + let gvars: Vec = module + .get_global_vars()? + .into_iter() + .map(|gv| gv.name_hint.as_str().unwrap().to_string()) + .collect(); for name in names { assert!(gvars.contains(&name.to_string())); @@ -306,10 +318,11 @@ mod tests { fn test_get_global_type_vars() -> Result<()> { let names = vec!["foo", "bar", "baz"]; let module = add_dummy_types(names.clone())?; - let gvars: Vec = - module.get_global_type_vars()?.into_iter().map(|gv| { - gv.name_hint.as_str().unwrap().to_string() - }).collect(); + let gvars: Vec = module + .get_global_type_vars()? + .into_iter() + .map(|gv| gv.name_hint.as_str().unwrap().to_string()) + .collect(); for name in names { assert!(gvars.contains(&name.to_string())); diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index 8dd96173e5a8..9d2983237acb 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -139,7 +139,11 @@ impl Var { pub fn static_tensor(name_hint: String, sh: Vec, dtype: DataType) -> Var { let sh = Array::from_vec(sh.into_iter().map(Into::into).collect()).unwrap(); - Self::new(name_hint, super::ty::TensorType::new(sh, dtype, Span::null()).upcast(), Span::null()) + Self::new( + name_hint, + super::ty::TensorType::new(sh, dtype, Span::null()).upcast(), + Span::null(), + ) } } @@ -497,9 +501,17 @@ impl Function { } pub fn simple(params: Vec, body: E) -> Function - where E: IsObjectRef, E::Object: AsRef<::Object> { + where + E: IsObjectRef, + E::Object: AsRef<::Object>, + { let params = Array::from_vec(params).unwrap(); - Self::new(params, body.upcast(), Type::null(), Array::from_vec(vec![]).unwrap()) + Self::new( + params, + body.upcast(), + Type::null(), + Array::from_vec(vec![]).unwrap(), + ) } } diff --git a/rust/tvm/src/ir/ty.rs b/rust/tvm/src/ir/ty.rs index f5b9cb12fd98..f7c52b51f332 100644 --- a/rust/tvm/src/ir/ty.rs +++ b/rust/tvm/src/ir/ty.rs @@ -17,14 +17,13 @@ * under the License. */ - use tvm_macros::Object; use tvm_rt::{array::Array, DataType}; -use crate::ir::span::Span; use crate::ir::relay::Constructor; +use crate::ir::span::Span; use crate::ir::PrimExpr; -use crate::runtime::{IsObject, Object, ObjectPtr, string::String as TString}; +use crate::runtime::{string::String as TString, IsObject, Object, ObjectPtr}; #[repr(C)] #[derive(Object, Debug)] @@ -84,7 +83,6 @@ pub struct PointerTypeNode { pub element_type: Type, } - /// Possible kinds of type variables. #[derive(PartialEq, Eq, Debug)] pub enum TypeKind { @@ -127,7 +125,9 @@ pub struct GlobalTypeVarNode { impl GlobalTypeVar { pub fn new(name_hint: S, kind: TypeKind, span: Span) -> GlobalTypeVar - where S: Into { + where + S: Into, + { let node = GlobalTypeVarNode { base: TypeNode::base::(span), name_hint: name_hint.into(), @@ -275,13 +275,19 @@ pub struct TypeDataNode { } impl TypeData { - pub fn new(type_name: GlobalTypeVar, type_vars: TypeVars, constructors: Ctors, span: Span) -> TypeData - where TypeVars: IntoIterator, - Ctors: IntoIterator, + pub fn new( + type_name: GlobalTypeVar, + type_vars: TypeVars, + constructors: Ctors, + span: Span, + ) -> TypeData + where + TypeVars: IntoIterator, + Ctors: IntoIterator, { use std::iter::FromIterator; let type_data = TypeDataNode { - base: TypeNode::base::(span,), + base: TypeNode::base::(span), type_name, type_vars: Array::from_iter(type_vars), constructors: Array::from_iter(constructors), diff --git a/src/ir/module.cc b/src/ir/module.cc index 25c21e81fcf9..0964199cd40a 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -425,9 +425,10 @@ TVM_REGISTER_GLOBAL("ir.Module_Add").set_body([](TVMArgs args, TVMRetValue* ret) *ret = mod; }); -TVM_REGISTER_GLOBAL("ir.Module_AddDef").set_body_typed([](IRModule module, GlobalTypeVar var, TypeData type_def, bool update) { - module->AddTypeDef(var, type_def, update); -}); +TVM_REGISTER_GLOBAL("ir.Module_AddDef") + .set_body_typed([](IRModule module, GlobalTypeVar var, TypeData type_def, bool update) { + module->AddTypeDef(var, type_def, update); + }); TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar") .set_body_method(&IRModuleNode::GetGlobalVar); @@ -444,7 +445,6 @@ TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar") TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalTypeVar") .set_body_method(&IRModuleNode::ContainGlobalTypeVar); - TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar") .set_body_method(&IRModuleNode::GetGlobalTypeVar); From 6875ac7b8aff2e6df888b24f507d3902b0c8dcf5 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 4 Nov 2020 19:12:32 -0800 Subject: [PATCH 11/12] Restore module.cc --- src/ir/module.cc | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/ir/module.cc b/src/ir/module.cc index 0964199cd40a..7990b281fb04 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -425,10 +425,7 @@ TVM_REGISTER_GLOBAL("ir.Module_Add").set_body([](TVMArgs args, TVMRetValue* ret) *ret = mod; }); -TVM_REGISTER_GLOBAL("ir.Module_AddDef") - .set_body_typed([](IRModule module, GlobalTypeVar var, TypeData type_def, bool update) { - module->AddTypeDef(var, type_def, update); - }); +TVM_REGISTER_GLOBAL("ir.Module_AddDef").set_body_method(&IRModuleNode::AddTypeDef); TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar") .set_body_method(&IRModuleNode::GetGlobalVar); From da2ab07031fd0019cc6bce37ca53730c2def0317 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 4 Nov 2020 20:46:19 -0800 Subject: [PATCH 12/12] Bump syn --- rust/tvm-macros/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/tvm-macros/Cargo.toml b/rust/tvm-macros/Cargo.toml index 8e97d3b670d8..e491177d8599 100644 --- a/rust/tvm-macros/Cargo.toml +++ b/rust/tvm-macros/Cargo.toml @@ -33,5 +33,5 @@ proc-macro = true goblin = "^0.2" proc-macro2 = "^1.0" quote = "^1.0" -syn = { version = "^1.0", features = ["full", "parsing", "extra-traits"] } +syn = { version = "1.0.48", features = ["full", "parsing", "extra-traits"] } proc-macro-error = "^1.0"