From cfa76f5d8604821dbe03f4fc9b4dc9a8726f31a1 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Thu, 10 Dec 2020 11:23:23 -0800 Subject: [PATCH 1/6] more rust bindings --- rust/tvm/src/ir/relay/attrs/nn.rs | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/rust/tvm/src/ir/relay/attrs/nn.rs b/rust/tvm/src/ir/relay/attrs/nn.rs index 7ecd92febc22..37115c916fda 100644 --- a/rust/tvm/src/ir/relay/attrs/nn.rs +++ b/rust/tvm/src/ir/relay/attrs/nn.rs @@ -106,3 +106,25 @@ pub struct BatchNormAttrsNode { pub center: bool, pub scale: bool, } + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "LeakyReluAttrs"] +#[type_key = "relay.attrs.LeakyReluAttrs"] +pub struct LeakyReluAttrsNode { + pub base: BaseAttrsNode, + pub alpha: f64 +} + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "AvgPool2DAttrs"] +#[type_key = "relay.attrs.AvgPool2DAttrs"] +pub struct AvgPool2DAttrsNode { + pub pool_size: Array, + pub strides: Array, + pub padding: Array, + pub layout: TString, + pub ceil_mode: bool, + pub count_include_pad: bool +} From f154ffa4970a60fa46211a57e5b5b3905c79f959 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Thu, 10 Dec 2020 11:25:42 -0800 Subject: [PATCH 2/6] add base --- rust/tvm/src/ir/relay/attrs/nn.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/rust/tvm/src/ir/relay/attrs/nn.rs b/rust/tvm/src/ir/relay/attrs/nn.rs index 37115c916fda..ff523dcb0302 100644 --- a/rust/tvm/src/ir/relay/attrs/nn.rs +++ b/rust/tvm/src/ir/relay/attrs/nn.rs @@ -121,6 +121,7 @@ pub struct LeakyReluAttrsNode { #[ref_name = "AvgPool2DAttrs"] #[type_key = "relay.attrs.AvgPool2DAttrs"] pub struct AvgPool2DAttrsNode { + pub base: BaseAttrsNode, pub pool_size: Array, pub strides: Array, pub padding: Array, From 5a8c94f8f26a5fc07e00dd34e44e3a9da7f28186 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Sat, 12 Dec 2020 12:44:01 -0800 Subject: [PATCH 3/6] more rust bindings fix --- rust/tvm/src/ir/relay/attrs/nn.rs | 13 ++++++ rust/tvm/src/ir/relay/attrs/transform.rs | 52 ++++++++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/rust/tvm/src/ir/relay/attrs/nn.rs b/rust/tvm/src/ir/relay/attrs/nn.rs index ff523dcb0302..41e28f2a281f 100644 --- a/rust/tvm/src/ir/relay/attrs/nn.rs +++ b/rust/tvm/src/ir/relay/attrs/nn.rs @@ -129,3 +129,16 @@ pub struct AvgPool2DAttrsNode { pub ceil_mode: bool, pub count_include_pad: bool } + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "UpSamplingAttrs"] +#[type_key = "relay.attrs.UpSamplingAttrs"] +pub struct UpSamplingAttrsNode { + pub base: BaseAttrsNode, + pub scale_h: f64, + pub scale_w: f64, + pub layout: TString, + pub method: TString, + pub align_corners: bool +} diff --git a/rust/tvm/src/ir/relay/attrs/transform.rs b/rust/tvm/src/ir/relay/attrs/transform.rs index c459f96b2d2f..aafd258a4a48 100644 --- a/rust/tvm/src/ir/relay/attrs/transform.rs +++ b/rust/tvm/src/ir/relay/attrs/transform.rs @@ -18,8 +18,13 @@ */ use crate::ir::attrs::BaseAttrsNode; +use crate::ir::PrimExpr; +use crate::runtime::array::Array; +use crate::runtime::ObjectRef; use tvm_macros::Object; +type IndexExpr = PrimExpr; + #[repr(C)] #[derive(Object, Debug)] #[ref_name = "ExpandDimsAttrs"] @@ -29,3 +34,50 @@ pub struct ExpandDimsAttrsNode { pub axis: i32, pub num_newaxis: i32, } + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "ConcatenateAttrs"] +#[type_key = "relay.attrs.ConcatenateAttrs"] +pub struct ConcatenateAttrsNode { + pub base: BaseAttrsNode, + pub axis: i32 +} + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "ReshapeAttrs"] +#[type_key = "relay.attrs.ReshapeAttrs"] +pub struct ReshapeAttrsNode { + pub base: BaseAttrsNode, + pub newshape: Array, + pub reverse: bool +} + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "SplitAttrs"] +#[type_key = "relay.attrs.SplitAttrs"] +pub struct SplitAttrsNode { + pub base: BaseAttrsNode, + pub indices_or_sections: ObjectRef, + pub axis: i32 +} + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "TransposeAttrs"] +#[type_key = "relay.attrs.TransposeAttrs"] +pub struct TransposeAttrsNode { + pub base: BaseAttrsNode, + pub axes: Array +} + +#[repr(C)] +#[derive(Object, Debug)] +#[ref_name = "SqueezeAttrs"] +#[type_key = "relay.attrs.SqueezeAttrs"] +pub struct SqueezeAttrsNode { + pub base: BaseAttrsNode, + pub axis: Array +} From 0225fcc6081df2c225cab214ad3bc6360e46d4d9 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Tue, 15 Dec 2020 12:02:06 -0800 Subject: [PATCH 4/6] cargo fmt --- rust/tvm/src/ir/relay/attrs/nn.rs | 16 ++++++++-------- rust/tvm/src/ir/relay/attrs/transform.rs | 24 ++++++++++++------------ 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/rust/tvm/src/ir/relay/attrs/nn.rs b/rust/tvm/src/ir/relay/attrs/nn.rs index 41e28f2a281f..f0137fa3cbcc 100644 --- a/rust/tvm/src/ir/relay/attrs/nn.rs +++ b/rust/tvm/src/ir/relay/attrs/nn.rs @@ -113,7 +113,7 @@ pub struct BatchNormAttrsNode { #[type_key = "relay.attrs.LeakyReluAttrs"] pub struct LeakyReluAttrsNode { pub base: BaseAttrsNode, - pub alpha: f64 + pub alpha: f64, } #[repr(C)] @@ -127,7 +127,7 @@ pub struct AvgPool2DAttrsNode { pub padding: Array, pub layout: TString, pub ceil_mode: bool, - pub count_include_pad: bool + pub count_include_pad: bool, } #[repr(C)] @@ -135,10 +135,10 @@ pub struct AvgPool2DAttrsNode { #[ref_name = "UpSamplingAttrs"] #[type_key = "relay.attrs.UpSamplingAttrs"] pub struct UpSamplingAttrsNode { - pub base: BaseAttrsNode, - pub scale_h: f64, - pub scale_w: f64, - pub layout: TString, - pub method: TString, - pub align_corners: bool + pub base: BaseAttrsNode, + pub scale_h: f64, + pub scale_w: f64, + pub layout: TString, + pub method: TString, + pub align_corners: bool, } diff --git a/rust/tvm/src/ir/relay/attrs/transform.rs b/rust/tvm/src/ir/relay/attrs/transform.rs index aafd258a4a48..b5f7c2047d62 100644 --- a/rust/tvm/src/ir/relay/attrs/transform.rs +++ b/rust/tvm/src/ir/relay/attrs/transform.rs @@ -40,8 +40,8 @@ pub struct ExpandDimsAttrsNode { #[ref_name = "ConcatenateAttrs"] #[type_key = "relay.attrs.ConcatenateAttrs"] pub struct ConcatenateAttrsNode { - pub base: BaseAttrsNode, - pub axis: i32 + pub base: BaseAttrsNode, + pub axis: i32, } #[repr(C)] @@ -49,9 +49,9 @@ pub struct ConcatenateAttrsNode { #[ref_name = "ReshapeAttrs"] #[type_key = "relay.attrs.ReshapeAttrs"] pub struct ReshapeAttrsNode { - pub base: BaseAttrsNode, - pub newshape: Array, - pub reverse: bool + pub base: BaseAttrsNode, + pub newshape: Array, + pub reverse: bool, } #[repr(C)] @@ -59,9 +59,9 @@ pub struct ReshapeAttrsNode { #[ref_name = "SplitAttrs"] #[type_key = "relay.attrs.SplitAttrs"] pub struct SplitAttrsNode { - pub base: BaseAttrsNode, - pub indices_or_sections: ObjectRef, - pub axis: i32 + pub base: BaseAttrsNode, + pub indices_or_sections: ObjectRef, + pub axis: i32, } #[repr(C)] @@ -69,8 +69,8 @@ pub struct SplitAttrsNode { #[ref_name = "TransposeAttrs"] #[type_key = "relay.attrs.TransposeAttrs"] pub struct TransposeAttrsNode { - pub base: BaseAttrsNode, - pub axes: Array + pub base: BaseAttrsNode, + pub axes: Array, } #[repr(C)] @@ -78,6 +78,6 @@ pub struct TransposeAttrsNode { #[ref_name = "SqueezeAttrs"] #[type_key = "relay.attrs.SqueezeAttrs"] pub struct SqueezeAttrsNode { - pub base: BaseAttrsNode, - pub axis: Array + pub base: BaseAttrsNode, + pub axis: Array, } From 810be3d9b328c02e9ad7908a2e2cccb04f218284 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Sun, 20 Dec 2020 19:24:25 -0800 Subject: [PATCH 5/6] fix upsampling attrs --- include/tvm/relay/attrs/nn.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index f8aa1fc508b6..0974f8233929 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -959,8 +959,8 @@ struct FIFOBufferAttrs : public tvm::AttrsNode { struct UpSamplingAttrs : public tvm::AttrsNode { double scale_h; double scale_w; - std::string layout; - std::string method; + tvm::String layout; + tvm::String method; bool align_corners; TVM_DECLARE_ATTRS(UpSamplingAttrs, "relay.attrs.UpSamplingAttrs") { From 23475b5c8d7ccbd7a9f3609fe6fe21c3fbe2b22f Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Wed, 23 Dec 2020 21:00:12 -0800 Subject: [PATCH 6/6] fix avgpool2d attrs --- include/tvm/relay/attrs/nn.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 0974f8233929..5fb45c934536 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -715,7 +715,7 @@ struct AvgPool2DAttrs : public tvm::AttrsNode { Array pool_size; Array strides; Array padding; - std::string layout; + tvm::String layout; bool ceil_mode; bool count_include_pad;