43 changes: 43 additions & 0 deletions include/tvm/relax/attrs/nn.h
@@ -29,6 +29,49 @@
namespace tvm {
namespace relax {

/*! \brief Attributes used in Conv1d operator */
struct Conv1DAttrs : public tvm::AttrsNode<Conv1DAttrs> {
Array<IntImm> strides;
Array<IntImm> padding;
Array<IntImm> dilation;
int groups;
String data_layout;
String kernel_layout;
String out_layout;
DataType out_dtype;

TVM_DECLARE_ATTRS(Conv1DAttrs, "relax.attrs.Conv1DAttrs") {
TVM_ATTR_FIELD(strides).describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).describe(
"If padding is non-zero, then the input is implicitly zero-padded"
"Padding support both symmetric and asymmetric as"
"one int : same padding used on both sides"
"two int : padding width in the order of (left, right)");
TVM_ATTR_FIELD(dilation).describe(
"Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(groups).describe(
"Number of groups to split the input into for grouped convolution. The number of input and "
"output channels should be divisible by the number of groups.");
TVM_ATTR_FIELD(data_layout)
.describe(
"Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
"'N', 'C', 'W' stands for batch, channel, width"
"dimensions respectively. Convolution is applied on the 'W' dimensions.");
TVM_ATTR_FIELD(kernel_layout)
.describe(
"Dimension ordering of weight. Can be 'OIW', 'IOW', etc."
"'O', 'I', 'W' stands for num_filter, input_channel, and width"
"dimensions respectively.");
TVM_ATTR_FIELD(out_layout)
.describe(
"Dimension ordering of output. Can be 'NCW', 'NWC', etc."
"'N', 'C', 'W' stands for batch, channel, and width"
"dimensions respectively. Default to be same as input layout.");
TVM_ATTR_FIELD(out_dtype).describe(
"Output data type, set to explicit type under mixed precision setting");
}
}; // struct Conv1DAttrs

/*! \brief Attributes used in Conv2d operator */
struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
Array<IntImm> strides;
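The new `Conv1DAttrs` mirrors `Conv2DAttrs` with the height dimension dropped. Below is a minimal sketch (not part of the diff; the shapes and sizes are made up) of how these attributes surface when a `conv1d` call is built from Python, including the asymmetric-padding case described above:

```python
import tvm
from tvm import relax

x = relax.Var("x", relax.TensorStructInfo((1, 3, 32), "float32"))
w = relax.Var("w", relax.TensorStructInfo((8, 3, 5), "float32"))

# Asymmetric padding: 2 on the left, 1 on the right.
call = relax.op.nn.conv1d(x, w, strides=1, padding=(2, 1), dilation=1)
print(call.attrs.strides, call.attrs.padding, call.attrs.data_layout)
# expected along the lines of: [1] [2, 1] NCW
```
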
29 changes: 29 additions & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -637,6 +637,34 @@ def _linear(self, node: fx.node.Node) -> relax.Var:
bias = None if module.bias is None else self.params[module.bias]
return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32"))

def _conv1d(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
weight = self.params[module.weight]

conv1d = self.block_builder.emit(
relax.op.nn.conv1d(
x,
weight,
strides=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
data_layout="NCW",
kernel_layout="OIW",
out_dtype="float32",
)
)

if module.bias is None:
return conv1d

bias = self.params[module.bias]
assert len(self.shape_of(bias)) == 1
bias = relax.op.reshape(bias, (1, -1, 1))

return self.block_builder.emit(relax.op.add(conv1d, bias))

def _conv2d(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
@@ -1001,6 +1029,7 @@ def create_convert_map(self):
self.convert_map: Dict[Union[nn.Module, str], Callable[[fx.node.Node], relax.Var]] = {
# call_module
nn.Linear: self._linear,
nn.Conv1d: self._conv1d,
nn.Conv2d: self._conv2d,
nn.MaxPool2d: self._max_pool2d,
nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d(is_module=True),
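The translator maps `nn.Conv1d` modules to `relax.op.nn.conv1d`, folding the bias in with a reshape to `(1, -1, 1)` followed by an add. A minimal sketch (not part of the diff; module sizes and input shape are illustrative) of exercising this path through `from_fx`:

```python
from torch import fx, nn
from tvm.relax.frontend.torch import from_fx

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(in_channels=3, out_channels=6, kernel_size=7, bias=True)

    def forward(self, x):
        return self.conv(x)

# Trace with torch.fx, then translate to a Relax IRModule.
graph_model = fx.symbolic_trace(M())
mod = from_fx(graph_model, [((1, 3, 10), "float32")])
mod.show()  # "main" should contain the conv1d call followed by reshape + add for the bias
```
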
98 changes: 98 additions & 0 deletions python/tvm/relax/op/nn/nn.py
@@ -23,6 +23,104 @@
from ...expr import Expr


def conv1d(
data: Expr,
weight: Expr,
strides: Union[int, Tuple[int]] = 1,
padding: Union[int, Tuple[int, ...]] = 0,
dilation: Union[int, Tuple[int]] = 1,
groups: int = 1,
data_layout: str = "NCW",
kernel_layout: str = "OIW",
out_layout: Optional[str] = None,
out_dtype: Optional[Union[str, DataType]] = None,
) -> Expr:
r"""1D convolution.

This operator takes the weight as the 1D convolution kernel
and convolves it with data to produce an output.


In the default case, where the data_layout is `NCW`
and kernel_layout is `OIW`, conv1d takes in
a data Tensor with shape `(batch_size, in_channels, width)`,
and a weight Tensor with shape `(channels, in_channels, kernel_w)`,
where `kernel_w` is the length of the `W` kernel dimension,
to produce an output Tensor with the following rule:

.. math::

\mbox{out}[b, c, x] = \sum_{dx, k}
\mbox{data}[b, k, \mbox{strides} * x + dx] *
\mbox{weight}[c, k, dx]

Padding and dilation are applied to data and weight respectively before the computation.
This operator accepts data layout specification.
Semantically, the operator will convert the layout to the canonical layout
(`NCW` for data and `OIW` for weight), perform the computation,
then convert to the out_layout.

Parameters
----------
data : relax.Expr
The input data to the operator.

weight : relax.Expr
The weight expressions.

strides : Union[int, Tuple[int]]
The strides of convolution. It is required to have length 1.

padding : Union[int, Tuple[int, ...]]
The padding added to both sides of the input before convolution.
It is required to have length either 1 or 2.

dilation : Union[int, Tuple[int]]
Specifies the dilation rate to be used for dilated convolution.
It is required to have length 1.

groups : int
Number of groups to split the input into for grouped convolution.
The number of input and output channels should be divisible by the number of groups.

data_layout : str
Layout of the input.

kernel_layout : str
Layout of the weight.

out_layout : Optional[str]
Layout of the output. If not specified, it is the same as data_layout.

out_dtype : Optional[Union[str, DataType]]
Specifies the output data type for mixed precision conv1d.

Returns
-------
result : relax.Expr
The computed result.
"""
if isinstance(strides, int):
strides = (strides,)
if isinstance(dilation, int):
dilation = (dilation,)
if isinstance(padding, int):
padding = (padding, padding)

return _ffi_api.conv1d( # type: ignore
data,
weight,
strides,
padding,
dilation,
groups,
data_layout,
kernel_layout,
out_layout,
out_dtype,
)


def conv2d(
data: Expr,
weight: Expr,
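A minimal usage sketch of the new operator (not part of the diff; shapes are illustrative), showing the output width implied by the rule above, out_w = (in_w + pad_left + pad_right - dilation * (kernel_w - 1) - 1) // stride + 1:

```python
import tvm
from tvm import relax

data = relax.Var("data", relax.TensorStructInfo((1, 3, 28), "float32"))
weight = relax.Var("weight", relax.TensorStructInfo((6, 3, 5), "float32"))

bb = relax.BlockBuilder()
# Normalizing the call infers its struct info.
call = bb.normalize(relax.op.nn.conv1d(data, weight, strides=2, padding=1))
print(call.struct_info)
# expected: a (1, 6, 13) float32 tensor, since (28 + 2 - 4 - 1) // 2 + 1 = 13
```
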
40 changes: 40 additions & 0 deletions python/tvm/relax/transform/legalize_ops/nn.py
@@ -24,6 +24,46 @@
from .common import register_legalize, _call_topi_without_attr


@register_legalize("relax.nn.conv1d")
def _nn_conv1d(bb: BlockBuilder, call: Call) -> Expr:
if call.attrs.out_layout != call.attrs.data_layout:
logging.info(
"TOPI conv1d does not support different input-output "
"layouts, and thus cannot be legalized by TOPI"
)
return call
if len(call.attrs.data_layout) != 3 or len(call.attrs.kernel_layout) != 3:
logging.info(
"Conv1D where data layout or kernel layout have channel chunk "
"cannot be legalized by TOPI at this moment."
)
return call
if call.attrs.groups != 1:
data_layout = tir.layout(call.attrs.data_layout)
kernel_layout = tir.layout(call.attrs.kernel_layout)
ic = call.args[0].struct_info.shape.values[data_layout.index_of("C")]
oc = call.args[1].struct_info.shape.values[kernel_layout.index_of("O")]
if not isinstance(ic, tir.IntImm) or not isinstance(oc, tir.IntImm):
logging.info(
"Conv1D where number of groups is more than one and input or output "
"channel size is symbolic cannot be legalized by TOPI at this moment."
)
return call

return bb.call_te(
topi.nn.conv1d,
data=call.args[0],
kernel=call.args[1],
strides=call.attrs.strides,
padding=call.attrs.padding,
dilation=call.attrs.dilation,
data_layout=call.attrs.data_layout,
kernel_layout=call.attrs.kernel_layout,
out_dtype=call.attrs.out_dtype if call.attrs.out_dtype != "" else None,
primfunc_name_hint="conv1d",
)


@register_legalize("relax.nn.conv2d")
def _nn_conv2d(bb: BlockBuilder, call: Call) -> Expr:
if call.attrs.out_layout != call.attrs.data_layout:
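To round out the picture, a minimal end-to-end sketch (not part of the diff; shapes are illustrative) of the legalization path above: build a small Relax function containing conv1d, then run LegalizeOps so the call is lowered to a TIR PrimFunc through topi.nn.conv1d:

```python
import tvm
from tvm import relax

x = relax.Var("x", relax.TensorStructInfo((1, 3, 32), "float32"))
w = relax.Var("w", relax.TensorStructInfo((8, 3, 5), "float32"))

bb = relax.BlockBuilder()
with bb.function("main", [x, w]):
    y = bb.emit(relax.op.nn.conv1d(x, w, strides=1, padding=1))
    bb.emit_func_output(y)

legalized = relax.transform.LegalizeOps()(bb.get())
legalized.show()  # expect a generated "conv1d" PrimFunc and a call_tir in "main"
```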