
Commit 1fa77f2

MasterJH5574 authored and tqchen committed
[Unity] Relax op: neural networks (#13993)
This PR continues the work on the high-level tensor computation operators in Relax, adding the neural network operators.
1 parent ba81ffa commit 1fa77f2
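
To make the scope concrete, here is a minimal usage sketch. The Python wrappers in python/tvm/relax/op/nn/nn.py are part of this commit but not shown in this excerpt, so the exact signatures below (conv2d, relu, softmax and their keyword arguments) are assumptions inferred from the attribute structs in include/tvm/relax/attrs/nn.h.

# Hypothetical usage sketch: assumes relax.op.nn exposes conv2d/relu/softmax
# wrappers whose keyword arguments mirror the C++ attribute structs below.
import tvm
from tvm import relax
from tvm.script import relax as R

# Symbolic inputs described by struct info (shape + dtype).
data = relax.Var("data", R.Tensor((1, 3, 224, 224), "float32"))
weight = relax.Var("weight", R.Tensor((16, 3, 3, 3), "float32"))

bb = relax.BlockBuilder()
with bb.function("main", [data, weight]):
    # Conv2DAttrs fields (strides, padding, dilation, groups, layouts)
    # surface as keyword arguments on the Python side.
    conv = bb.emit(relax.op.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1)))
    act = bb.emit(relax.op.nn.relu(conv))
    bb.emit_func_output(relax.op.nn.softmax(act, axis=-1))
mod = bb.get()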

File tree

17 files changed: +3536 -0 lines changed


include/tvm/relax/attrs/nn.h

Lines changed: 190 additions & 0 deletions
@@ -0,0 +1,190 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relax/attrs/nn.h
 * \brief Attributes for neural network operators.
 */
#ifndef TVM_RELAX_ATTRS_NN_H_
#define TVM_RELAX_ATTRS_NN_H_

#include <tvm/relax/expr.h>

namespace tvm {
namespace relax {

/*! \brief Attributes used in Conv2d operator */
struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
  Array<IntImm> strides;
  Array<IntImm> padding;
  Array<IntImm> dilation;
  int groups;
  String data_layout;
  String kernel_layout;
  String out_layout;
  DataType out_dtype;

  TVM_DECLARE_ATTRS(Conv2DAttrs, "relax.attrs.Conv2DAttrs") {
    TVM_ATTR_FIELD(strides).describe("Specifies the strides of the convolution.");
    TVM_ATTR_FIELD(padding).describe(
        "If padding is non-zero, the input is implicitly zero-padded. "
        "Padding supports both symmetric and asymmetric forms: "
        "one int: the same padding is used on all sides; "
        "two ints: bottom and right use the same padding as top and left; "
        "four ints: padding widths in the order (top, left, bottom, right).");
    TVM_ATTR_FIELD(dilation).describe(
        "Specifies the dilation rate to use for dilated convolution.");
    TVM_ATTR_FIELD(groups).describe(
        "Number of groups to split the input into for grouped convolution. The number of input "
        "and output channels should be divisible by the number of groups.");
    TVM_ATTR_FIELD(data_layout)
        .describe(
            "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc. "
            "'N', 'C', 'H', 'W' stand for the batch, channel, height, and width "
            "dimensions respectively. Convolution is applied on the 'H' and "
            "'W' dimensions.");
    TVM_ATTR_FIELD(kernel_layout)
        .describe(
            "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc. "
            "'O', 'I', 'H', 'W' stand for the num_filter, input_channel, height, and width "
            "dimensions respectively.");
    TVM_ATTR_FIELD(out_layout)
        .describe(
            "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc. "
            "'N', 'C', 'H', 'W' stand for the batch, channel, height, and width "
            "dimensions respectively. Defaults to the same layout as the input.");
    TVM_ATTR_FIELD(out_dtype).describe(
        "Output data type; set to an explicit type under a mixed-precision setting.");
  }
};  // struct Conv2DAttrs

/*! \brief Attributes used in max_pool2d operator */
struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
  Array<IntImm> pool_size;
  Array<IntImm> strides;
  Array<IntImm> padding;
  Array<IntImm> dilation;
  bool ceil_mode;
  String layout;
  String out_layout;

  TVM_DECLARE_ATTRS(MaxPool2DAttrs, "relax.attrs.MaxPool2DAttrs") {
    TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows.");
    TVM_ATTR_FIELD(strides).describe("Specifies the strides of the pooling.");
    TVM_ATTR_FIELD(dilation).describe("Specifies the dilation of the pooling.");
    TVM_ATTR_FIELD(padding).describe(
        "If padding is non-zero, the input is implicitly zero-padded. "
        "Padding supports both symmetric and asymmetric forms: "
        "one int: the same padding is used on all sides; "
        "two ints: bottom and right use the same padding as top and left; "
        "four ints: padding widths in the order (top, left, bottom, right).");
    TVM_ATTR_FIELD(ceil_mode).describe(
        "A boolean indicating whether to use ceil or floor to compute the output shape. "
        "With ceil, every element in the input tensor is covered by a sliding window.");
    TVM_ATTR_FIELD(layout).describe(
        "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc. "
        "'N', 'C', 'H', 'W' stand for the batch, channel, height, and width "
        "dimensions respectively. Pooling is applied on the 'H' and "
        "'W' dimensions.");
    TVM_ATTR_FIELD(out_layout)
        .describe(
            "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc. "
            "'N', 'C', 'H', 'W' stand for the batch, channel, height, and width "
            "dimensions respectively. Pooling is applied on the 'H' and "
            "'W' dimensions.");
  }
};  // struct MaxPool2DAttrs

/*! \brief Attributes for 2d adaptive pool operator */
struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
  Optional<Array<IntImm>> output_size;
  String layout;
  String out_layout;

  TVM_DECLARE_ATTRS(AdaptivePool2DAttrs, "relax.attrs.AdaptivePool2DAttrs") {
    TVM_ATTR_FIELD(output_size).describe("Output height and width.");
    TVM_ATTR_FIELD(layout).describe(
        "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc. "
        "'N', 'C', 'H', 'W' stand for the batch, channel, height, and width "
        "dimensions respectively. Pooling is applied on the 'H' and "
        "'W' dimensions.");
    TVM_ATTR_FIELD(out_layout)
        .describe(
            "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc. "
            "'N', 'C', 'H', 'W' stand for the batch, channel, height, and width "
            "dimensions respectively. Pooling is applied on the 'H' and "
            "'W' dimensions.");
  }
};  // struct AdaptivePool2DAttrs

/*! \brief Attributes used in softmax operators */
struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {
  int axis;

  TVM_DECLARE_ATTRS(SoftmaxAttrs, "relax.attrs.SoftmaxAttrs") {
    TVM_ATTR_FIELD(axis).describe("The axis to sum over when computing softmax.");
  }
};

/*! \brief Attributes used in batch_norm operator */
struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
  int axis;
  double epsilon;
  bool center;
  bool scale;

  TVM_DECLARE_ATTRS(BatchNormAttrs, "relax.attrs.BatchNormAttrs") {
    TVM_ATTR_FIELD(axis).describe("The axis along which the normalization is applied.");
    TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero.");
    TVM_ATTR_FIELD(center).describe(
        "Indicates whether the beta offset is added to the normalized tensor.");
    TVM_ATTR_FIELD(scale).describe(
        "Indicates whether the result is multiplied by the gamma scale.");
  }
};  // struct BatchNormAttrs

/*! \brief Attributes used in layer_norm operator */
struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> {
  Array<Integer> axes;
  double epsilon;
  bool center;
  bool scale;

  TVM_DECLARE_ATTRS(LayerNormAttrs, "relax.attrs.LayerNormAttrs") {
    TVM_ATTR_FIELD(axes).describe("The axes along which the normalization is applied.");
    TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero.");
    TVM_ATTR_FIELD(center).describe(
        "Indicates whether the beta offset is added to the normalized tensor.");
    TVM_ATTR_FIELD(scale).describe(
        "Indicates whether the result is multiplied by the gamma scale.");
  }
};  // struct LayerNormAttrs

/*! \brief Attributes used in dropout operator */
struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
  double rate;

  TVM_DECLARE_ATTRS(DropoutAttrs, "relax.attrs.DropoutAttrs") {
    TVM_ATTR_FIELD(rate).describe(
        "Fraction of the input that gets dropped out during training time.");
  }
};  // struct DropoutAttrs

}  // namespace relax
}  // namespace tvm

#endif  // TVM_RELAX_ATTRS_NN_H_
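
The padding convention documented in Conv2DAttrs and MaxPool2DAttrs accepts one, two, or four ints. As an illustration only (this helper is not part of the PR), the three forms expand as follows:

# Sketch of the padding expansion described in Conv2DAttrs/MaxPool2DAttrs.
# Illustrative helper only; it does not appear anywhere in this commit.
def expand_padding(padding):
    """Expand 1-, 2-, or 4-int padding to (top, left, bottom, right)."""
    if len(padding) == 1:
        p = padding[0]
        return (p, p, p, p)  # one int: same padding on all sides
    if len(padding) == 2:
        top, left = padding
        return (top, left, top, left)  # two ints: bottom/right mirror top/left
    if len(padding) == 4:
        return tuple(padding)  # four ints: (top, left, bottom, right)
    raise ValueError("padding must have 1, 2, or 4 elements")

assert expand_padding([1]) == (1, 1, 1, 1)
assert expand_padding([1, 2]) == (1, 2, 1, 2)
assert expand_padding([0, 1, 2, 3]) == (0, 1, 2, 3)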

python/tvm/relax/op/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@
 from . import builtin
 from . import image
 from . import memory
+from . import nn


 def _register_op_make():

python/tvm/relax/op/nn/__init__.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""Neural network related operators."""
from .nn import *

python/tvm/relax/op/nn/_ffi_api.py

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
import tvm._ffi

tvm._ffi._init_api("relax.op.nn", __name__)
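
The `_init_api` call above exposes every global packed function registered under the `relax.op.nn` prefix as an attribute of this module. A typical wrapper in python/tvm/relax/op/nn/nn.py (part of this commit but not shown in this excerpt) would then dispatch through it; the sketch below assumes a `softmax` packed function registered on the C++ side, matching SoftmaxAttrs above.

# Illustrative wrapper, assuming the C++ side registered "relax.op.nn.softmax";
# the real wrappers live elsewhere in this PR.
from . import _ffi_api


def softmax(data, axis=-1):
    """Compute softmax over ``axis`` of ``data``, mirroring SoftmaxAttrs."""
    return _ffi_api.softmax(data, axis)  # dispatches to the registered packed func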
