From 7cb590e7ddd8f4cf3ff7a9e50f4d94cbe77c640f Mon Sep 17 00:00:00 2001 From: Lesheng Jin Date: Wed, 22 Mar 2023 17:48:46 -0700 Subject: [PATCH 1/5] add conv1d to op --- include/tvm/relax/attrs/nn.h | 43 ++ python/tvm/relax/op/nn/nn.py | 98 +++++ src/relax/op/nn/convolution.cc | 155 +++++++ src/relax/op/nn/convolution.h | 5 + src/relax/op/op_common.h | 20 + tests/python/relax/test_op_nn_convolution.py | 378 +++++++++++++++++- .../relax/test_tvmscript_parser_op_nn.py | 18 + 7 files changed, 716 insertions(+), 1 deletion(-) diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index f49cb6b1215b..3daa32fd76b6 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -29,6 +29,49 @@ namespace tvm { namespace relax { +/*! \brief Attributes used in Conv1d operator */ +struct Conv1DAttrs : public tvm::AttrsNode { + Array strides; + Array padding; + Array dilation; + int groups; + String data_layout; + String kernel_layout; + String out_layout; + DataType out_dtype; + + TVM_DECLARE_ATTRS(Conv1DAttrs, "relax.attrs.Conv1DAttrs") { + TVM_ATTR_FIELD(strides).describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding).describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on both sides" + "two int : padding width in the order of (left, right)"); + TVM_ATTR_FIELD(dilation).describe( + "Specifies the dilation rate to use for dilated convolution."); + TVM_ATTR_FIELD(groups).describe( + "Number of groups to split the input into for grouped convolution. The number of input and " + "output channels should be divisible by the number of groups."); + TVM_ATTR_FIELD(data_layout) + .describe( + "Dimension ordering of input data. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, width" + "dimensions respectively. Convolution is applied on the 'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .describe( + "Dimension ordering of weight. Can be 'OIW', 'IOW', etc." + "'O', 'I', 'W' stands for num_filter, input_channel, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .describe( + "Dimension ordering of output. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(out_dtype).describe( + "Output data type, set to explicit type under mixed precision setting"); + } +}; // struct Conv1dAttrs + /*! \brief Attributes used in Conv2d operator */ struct Conv2DAttrs : public tvm::AttrsNode { Array strides; diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index c774bbc92651..e1d41c6cdfd6 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -23,6 +23,104 @@ from ...expr import Expr +def conv1d( + data: Expr, + weight: Expr, + strides: Union[int, Tuple[int]] = 1, + padding: Union[int, Tuple[int, ...]] = 0, + dilation: Union[int, Tuple[int]] = 1, + groups: int = 1, + data_layout: str = "NCW", + kernel_layout: str = "OIW", + out_layout: Optional[str] = None, + out_dtype: Optional[Union[str, DataType]] = None, +) -> Expr: + r"""1D convolution. + + This operator takes the weight as the 1D convolution kernel + and convolves it with data to produce an output. 
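+    For example, with the default stride 1, zero padding and dilation 1, a
+    `(2, 3, 28)` input convolved with a `(4, 3, 3)` weight yields a
+    `(2, 4, 26)` output (output width = 28 - 3 + 1 = 26).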
+ + + In the default case, where the data_layout is `NCW` + and kernel_layout is `OIW`, conv1d takes in + a data Tensor with shape `(batch_size, in_channels, width)`, + and a weight Tensor with shape `(channels, in_channels, kernel_w)`, + where `kernel_w` is the length of the `W` kernel dimension, + to produce an output Tensor with the following rule: + + .. math:: + + \mbox{out}[b, c, x] = \sum_{dx, k} + \mbox{data}[b, k, \mbox{strides} * x + dx] * + \mbox{weight}[c, k, dx] + + Padding and dilation are applied to data and weight respectively before the computation. + This operator accepts data layout specification. + Semantically, the operator will convert the layout to the canonical layout + (`NCW` for data and `OIW` for weight), perform the computation, + then convert to the out_layout. + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + weight : relax.Expr + The weight expressions. + + strides : Union[int, Tuple[int]] + The strides of convolution. It is required to have length 1. + + padding : Union[int, Tuple[int, ...]] + The padding of convolution on both sides of inputs before convolution. + It is required to have length either 1 or 2. + + dilation : Union[int, Tuple[int, int]] + Specifies the dilation rate to be used for dilated convolution. + It is required to have length 1. + + groups : int + Number of groups to split the input into for grouped convolution. + The number of input and output channels should be divisible by the number of groups. + + data_layout : str + Layout of the input. + + kernel_layout : str + Layout of the weight. + + out_layout : Optional[str] + Layout of the output. If not specified, it is the same as data_layout + + out_dtype : Optional[Union[str, DataType]] + Specifies the output data type for mixed precision conv1d. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(strides, int): + strides = (strides,) + if isinstance(dilation, int): + dilation = (dilation,) + if isinstance(padding, int): + padding = (padding, padding) + + return _ffi_api.conv1d( # type: ignore + data, + weight, + strides, + padding, + dilation, + groups, + data_layout, + kernel_layout, + out_layout, + out_dtype, + ) + + def conv2d( data: Expr, weight: Expr, diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index e10d205b23b6..ae84409c2a14 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -29,6 +29,161 @@ namespace tvm { namespace relax { +/* relax.nn.conv1d */ +TVM_REGISTER_NODE_TYPE(Conv1DAttrs); + +Expr conv1d(Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, String data_layout, String kernel_layout, + Optional out_layout, DataType out_dtype) { + padding = GetCompletePadding1D(std::move(padding)); + + CHECK_GT(groups, 0) << "The number of groups in convolution is expected to be positive. However, " + "the given number of groups is " + << groups; + CHECK_EQ(strides.size(), 1) + << "The input strides length is expected to be 1. However, the given strides is " << strides; + CHECK_EQ(dilation.size(), 1) + << "The input dilation length is expected to be 1. 
However, the given dilation is " + << dilation; + return MakeConv(std::move(data), std::move(weight), std::move(strides), + std::move(padding), std::move(dilation), groups, data_layout, + std::move(kernel_layout), out_layout.value_or(data_layout), + out_dtype, /*op_name=*/"relax.nn.conv1d"); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.conv1d").set_body_typed(conv1d); + +StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo data_sinfo = input_sinfo[0]; + TensorStructInfo weight_sinfo = input_sinfo[1]; + + const auto* attrs = call->attrs.as(); + auto [data_layout, data2NCW] = CheckTensorLayout(call, ctx, attrs->data_layout, // + /*tgt_layout=*/"NCW", // + /*tensor_name=*/"data"); + auto [weight_layout, weight2OIW] = CheckTensorLayout(call, ctx, attrs->kernel_layout, // + /*tgt_layout=*/"OIW", // + /*tensor_name=*/"kernel"); + auto [out_layout, out2NCW] = CheckTensorLayout(call, ctx, attrs->out_layout, // + /*tgt_layout=*/"NCW", // + /*tensor_name=*/"output"); + + Optional data_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout); + Optional weight_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout); + + DataType out_dtype = attrs->out_dtype.is_void() + ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo) + : attrs->out_dtype; + if (!data_shape.defined() || !weight_shape.defined()) { + return TensorStructInfo(out_dtype, out_layout.ndim()); + } + + Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); + Array weight_OIW_shape = weight2OIW.ForwardShape(weight_shape.value()->values); + + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + PrimExpr input_channel_data = data_NCW_shape[1]; + PrimExpr input_channel_kernel = weight_OIW_shape[1]; + if (analyzer->CanProve(input_channel_data != input_channel_kernel * attrs->groups)) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "The channel size of the data should equal to the product of input channel size of the " + "weight and the number of groups. However, the data channel size is " + << input_channel_data << " while the weight input channel size and number of groups are " + << input_channel_kernel << " and " << attrs->groups); + } else if (!analyzer->CanProveEqual(input_channel_data, input_channel_kernel * attrs->groups)) { + // Todo(relax-team): Trust the input shape at this moment, and revisit + // this condition with runtime shape check + } + if (analyzer->CanProve(floormod(weight_OIW_shape[0], attrs->groups) != 0)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Conv1d expects the number of output channels to be divisible by the " + "number of groups. 
However, the number of output channels is " + << weight_OIW_shape[0] << " while the number of groups is " << attrs->groups); + } else if (!analyzer->CanProveEqual(floormod(weight_OIW_shape[0], attrs->groups), 0)) { + // Todo(relax-team): Trust the input shape at this moment, and revisit + // this condition with runtime shape check + } + + PrimExpr input_w = data_NCW_shape[2]; + PrimExpr kernel_w = weight_OIW_shape[2]; + PrimExpr padding_w = attrs->padding[0] + attrs->padding[1]; + + std::vector out_NCW_shape; + out_NCW_shape.resize(3); + out_NCW_shape[0] = data_NCW_shape[0]; + out_NCW_shape[1] = weight_OIW_shape[0]; + + PrimExpr numerator_w = input_w + padding_w - attrs->dilation[0] * (kernel_w - 1) - 1; + out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[0]) + 1); + + Array out_shape = out2NCW.BackwardShape(out_NCW_shape); + return TensorStructInfo(ShapeExpr(out_shape), out_dtype); +} + +InferLayoutOutput InferLayoutConv1d(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + const auto& it = desired_layouts.find("relax.nn.conv1d"); + const auto* attrs = call->attrs.as(); + ICHECK(attrs) << "Invalid Call"; + + LayoutDecision data_layout, weight_layout, output_layout; + ObjectPtr new_attrs = make_object(*attrs); + + if (it != desired_layouts.end()) { + // We have a desired layout for conv1d. + Layout desired_data_layout = (*it).second[0]; + Layout desired_weight_layout = (*it).second[1]; + Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; + ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) << "Axis swap only"; + ICHECK_EQ(desired_weight_layout.ndim(), desired_weight_layout.ndim_primal()) + << "Axis swap only"; + ICHECK_EQ(desired_output_layout.ndim(), desired_output_layout.ndim_primal()) + << "Axis swap only"; + data_layout = TransposeLike(InitialLayout(3), attrs->data_layout, desired_data_layout); + weight_layout = TransposeLike(InitialLayout(3), attrs->kernel_layout, desired_weight_layout); + output_layout = TransposeLike(InitialLayout(3), attrs->out_layout, desired_output_layout); + new_attrs->data_layout = (*it).second[0]; + new_attrs->kernel_layout = (*it).second[1]; + new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; + } else { + // We don't have a desired layout for conv1d. + // We can just propagate the layout from the input. 
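+    // Concretely: reuse the layout decisions already made for the data and weight
+    // arguments, let the output follow the data layout, and rewrite the recorded
+    // layout attributes to match the chosen layouts.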
+ data_layout = GetLayoutDecision(var_layout_map, call->args[0]); + weight_layout = GetLayoutDecision(var_layout_map, call->args[1]); + output_layout = data_layout; + new_attrs->data_layout = + TransposeLike(attrs->data_layout, InitialLayout(3), data_layout->layout).name(); + new_attrs->kernel_layout = + TransposeLike(attrs->kernel_layout, InitialLayout(3), weight_layout->layout).name(); + new_attrs->out_layout = + TransposeLike(attrs->out_layout, InitialLayout(3), output_layout->layout).name(); + } + return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); +} + +Call InferMixedPrecisionConv1d(const Call& call, const DataType& out_dtype) { + const auto* conv1d_attrs = call->attrs.as(); + return Downcast(conv1d(call->args[0], call->args[1], conv1d_attrs->strides, + conv1d_attrs->padding, conv1d_attrs->dilation, conv1d_attrs->groups, + conv1d_attrs->data_layout, conv1d_attrs->kernel_layout, + conv1d_attrs->out_layout, out_dtype)); +} + +TVM_REGISTER_OP("relax.nn.conv1d") + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_attrs_type() + .set_attr("FInferStructInfo", InferStructInfoConv1d) + .set_attr("FRelaxInferLayout", InferLayoutConv1d) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) + .set_attr("FInferMixedPrecision", InferMixedPrecisionConv1d); + /* relax.nn.conv2d */ TVM_REGISTER_NODE_TYPE(Conv2DAttrs); diff --git a/src/relax/op/nn/convolution.h b/src/relax/op/nn/convolution.h index 7093c6a4d979..833e730ee949 100644 --- a/src/relax/op/nn/convolution.h +++ b/src/relax/op/nn/convolution.h @@ -52,6 +52,11 @@ inline Expr MakeConv(Expr data, Expr weight, Array strides, Array strides, Array padding, + Array dilation, int groups, String data_layout, String kernel_layout, + Optional out_layout, DataType out_dtype); + /*! \brief 2D convolution */ Expr conv2d(Expr data, Expr weight, Array strides, Array padding, Array dilation, int groups, String data_layout, String kernel_layout, diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index bd5f2cd4d55e..616dded39e52 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -240,6 +240,26 @@ inline Array ConvertIntImmToInt64(const Array& int_imms) { /************ Utilities for NN operators ************/ +/*! + * \brief Complete the padding to a 2-length array. + * - If the padding length is 1, the same padding is used on all left/right sides + * - If the padding length is 2, padding is in the order of (left, right) + * \param padding The given padding to be completed + * \return The completed padding. + * \throws Throws error if the input padding length is neither 1 or 2. + */ +inline Array GetCompletePadding1D(Array padding) { + if (padding.size() == 1) { + return {padding[0], padding[0]}; + } else if (padding.size() == 2) { + return padding; + } + LOG(FATAL) << "The input padding length is expected to be either 1 or 2. However, the given " + "padding is " + << padding; + throw; +} + /*! * \brief Complete the padding to a 4-length array. 
* - If the padding length is 1, the same padding is used on all top/left/bottom/right sides diff --git a/tests/python/relax/test_op_nn_convolution.py b/tests/python/relax/test_op_nn_convolution.py index 334f6977f7f3..d1d604429e93 100644 --- a/tests/python/relax/test_op_nn_convolution.py +++ b/tests/python/relax/test_op_nn_convolution.py @@ -23,7 +23,13 @@ from tvm.script import relax as R -def test_op_correctness(): +def test_conv1d_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3), "float32")) + assert relax.op.nn.conv1d(x, w).op == Op.get("relax.nn.conv1d") + + +def test_conv2d_op_correctness(): x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32")) assert relax.op.nn.conv2d(x, w).op == Op.get("relax.nn.conv2d") @@ -35,6 +41,376 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) +def test_conv1d_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, 28, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=3)) + x3 = relax.Var("x", R.Tensor("float32")) + x4 = relax.Var("x", R.Tensor()) + x5 = relax.Var("x", R.Tensor((2, 4, 28, 16), "float32")) + w0 = relax.Var("w", R.Tensor((4, 3, 3), "float32")) + w1 = relax.Var("w", R.Tensor((3, 4, 3), "float32")) + w2 = relax.Var("w", R.Tensor("float32", ndim=3)) + w3 = relax.Var("w", R.Tensor("float32")) + w4 = relax.Var("w", R.Tensor((48, 4, 3, 16), "float32")) + + _check_inference(bb, relax.op.nn.conv1d(x0, w0), relax.TensorStructInfo((2, 4, 26), "float32")) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w0, out_dtype="float16"), + relax.TensorStructInfo((2, 4, 26), "float16"), + ) + _check_inference( + bb, relax.op.nn.conv1d(x0, w0, padding=1), relax.TensorStructInfo((2, 4, 28), "float32") + ) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w0, padding=[1, 3]), + relax.TensorStructInfo((2, 4, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w0, strides=2), + relax.TensorStructInfo((2, 4, 13), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w0, strides=(2,)), + relax.TensorStructInfo((2, 4, 13), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w0, dilation=2), + relax.TensorStructInfo((2, 4, 24), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w0, dilation=(2,)), + relax.TensorStructInfo((2, 4, 24), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x1, w0, data_layout="NWC"), + relax.TensorStructInfo((2, 26, 4), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w0, out_layout="NWC"), + relax.TensorStructInfo((2, 26, 4), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w1, kernel_layout="IOW"), + relax.TensorStructInfo((2, 4, 26), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d( + x5, w4, data_layout="NCW16c", kernel_layout="OIW16i", out_layout="NWC16c" + ), + relax.TensorStructInfo((2, 26, 3, 16), "float32"), + ) + _check_inference( + bb, relax.op.nn.conv1d(x2, w0), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.nn.conv1d(x3, w0), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.nn.conv1d(x0, w2), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.nn.conv1d(x0, w3), 
relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference(bb, relax.op.nn.conv1d(x4, w0), relax.TensorStructInfo(dtype="", ndim=3)) + + +def test_conv1d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + c16 = tir.Var("c16", "int64") + iw = tir.Var("iw", "int64") + ki = tir.Var("ki", "int64") + ko = tir.Var("ko", "int64") + kw = tir.Var("kw", "int64") + x0 = relax.Var("x", R.Tensor((n, c, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, iw, c16), "float32")) + w0 = relax.Var("w", R.Tensor((ko, ki, kw), "float32")) + w1 = relax.Var("w", R.Tensor((ko, c, kw), "float32")) + w2 = relax.Var("w", R.Tensor((ko, c, kw, c16), "float32")) + + _check_inference( + bb, + relax.op.nn.conv1d(x0, w0), + relax.TensorStructInfo((n, ko, iw + 1 - kw), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w1), + relax.TensorStructInfo((n, ko, iw + 1 - kw), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x1, w2, data_layout="NCW16c", kernel_layout="OIW16i", out_layout="NCW"), + relax.TensorStructInfo((n, ko, iw + 1 - kw), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w0, strides=2, padding=1, dilation=2), + relax.TensorStructInfo( + (n, ko, tvm.tir.floordiv(iw + 3, 2) + 1 - kw), + "float32", + ), + ) + + +def test_conv1d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s3 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) + w = relax.Var("w", relax.TensorStructInfo(s2, "float32")) + + _check_inference(bb, relax.op.nn.conv1d(x0, w), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference( + bb, + relax.op.nn.conv1d(x1, w, data_layout="NCW16c"), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w, out_layout="NCW16c"), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x2, w), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + + +def test_conv1d_infer_struct_info_groups(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 128, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, 8, 28, 16), "float32")) + w0 = relax.Var("w", R.Tensor((48, 16, 3), "float32")) + w1 = relax.Var("w", R.Tensor((48, 2, 3, 8), "float32")) + + _check_inference( + bb, relax.op.nn.conv1d(x0, w0, groups=8), relax.TensorStructInfo((2, 48, 26), "float32") + ) + _check_inference( + bb, + relax.op.nn.conv1d(x0, w1, kernel_layout="OIW8i", groups=8), + relax.TensorStructInfo((2, 48, 26), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x1, w0, data_layout="NCW16c", groups=8), + relax.TensorStructInfo((2, 3, 26, 16), "float32"), + ) + + +def test_conv1d_infer_struct_info_symbolic_groups(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + ic = tir.Var("c", "int64") + oc = tir.Var("oc", "int64") + x = relax.Var("x", R.Tensor((n, ic * 4, 28), "float32")) + w0 = relax.Var("w", R.Tensor((oc * 4, ic, 3), "float32")) + w1 = relax.Var("w", R.Tensor((oc, ic, 3), "float32")) + + _check_inference( + bb, + relax.op.nn.conv1d(x, w0, groups=4), + relax.TensorStructInfo((n, oc * 4, 26), "float32"), + ) + _check_inference( + bb, 
relax.op.nn.conv1d(x, w1, groups=4), relax.TensorStructInfo((n, oc, 26), "float32") + ) + + +def test_conv1d_infer_struct_info_input_channel_group_incompatible(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + ic = tir.Var("c", "int64") + oc = tir.Var("oc", "int64") + x0 = relax.Var("x", R.Tensor((2, 128, 28), "float32")) + w0 = relax.Var("w", R.Tensor((48, 20, 3), "float32")) + x1 = relax.Var("x", R.Tensor((n, ic * 6, 28), "float32")) + w1 = relax.Var("w", R.Tensor((oc, ic - 1, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x0, w0, groups=6)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x1, w1, groups=6)) + + +def test_conv1d_infer_struct_info_output_channel_group_incompatible(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + ic = tir.Var("c", "int64") + oc = tir.Var("oc", "int64") + x0 = relax.Var("x", R.Tensor((2, 120, 28), "float32")) + w0 = relax.Var("w", R.Tensor((128, 20, 3), "float32")) + x1 = relax.Var("x", R.Tensor((n, ic * 6, 28), "float32")) + w1 = relax.Var("w", R.Tensor((oc * 6 + 4, ic * 6, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x0, w0, groups=6)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x1, w1, groups=6)) + + +def test_conv1d_non_positive_group(): + x = relax.Var("x", R.Tensor((2, 128, 28), "float32")) + w = relax.Var("w", R.Tensor((48, 16, 3), "float32")) + + with pytest.raises(TVMError): + relax.op.nn.conv1d(x, w, groups=0) + with pytest.raises(TVMError): + relax.op.nn.conv1d(x, w, groups=-2) + + +def test_conv1d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28), "float16")) + w0 = relax.Var("w", R.Tensor((4, 3, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 28), "float64")) + w1 = relax.Var("w", R.Tensor((4, 3, 3), "float64")) + x2 = relax.Var("x", R.Tensor((2, 3, 28), "int8")) + w2 = relax.Var("w", R.Tensor((4, 3, 3), "int8")) + x3 = relax.Var("x", R.Tensor((2, 3, 28), "int32")) + w3 = relax.Var("w", R.Tensor((4, 3, 3), "int32")) + + _check_inference(bb, relax.op.nn.conv1d(x0, w0), relax.TensorStructInfo((2, 4, 26), "float16")) + _check_inference(bb, relax.op.nn.conv1d(x1, w1), relax.TensorStructInfo((2, 4, 26), "float64")) + _check_inference(bb, relax.op.nn.conv1d(x2, w2), relax.TensorStructInfo((2, 4, 26), "int8")) + _check_inference(bb, relax.op.nn.conv1d(x3, w3), relax.TensorStructInfo((2, 4, 26), "int32")) + + +def test_conv1d_infer_struct_info_mixed_precision(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28), "float16")) + w0 = relax.Var("w", R.Tensor((4, 3, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 28), "int8")) + w1 = relax.Var("w", R.Tensor((4, 3, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 28))) + w2 = relax.Var("w", R.Tensor((4, 3, 3))) + + _check_inference( + bb, + relax.op.nn.conv1d(x0, w0, out_dtype="float32"), + relax.TensorStructInfo((2, 4, 26), "float32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x1, w1, out_dtype="int32"), + relax.TensorStructInfo((2, 4, 26), "int32"), + ) + _check_inference( + bb, + relax.op.nn.conv1d(x2, w2, out_dtype="float32"), + relax.TensorStructInfo((2, 4, 26), "float32"), + ) + + +def test_conv1d_unequal_input_channel(): + bb = relax.BlockBuilder() + ic = tir.Var("ic", "int64") + x0 = relax.Var("x", R.Tensor([2, 3, 28], "float32")) + w0 = relax.Var("w", R.Tensor([3, 4, 3], "float32")) + x1 = relax.Var("x", R.Tensor([2, ic, 28], "float32")) + w1 = relax.Var("w", 
R.Tensor([4, ic + 2, 3], "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x0, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x1, w1)) + + +def test_conv1d_stride_padding_dilation_int64(): + x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3), "float32")) + conv1d = relax.op.nn.conv1d(x, w, strides=(1,), padding=(1, 1), dilation=(1,)) + + assert conv1d.attrs.strides[0].dtype == "int64" + assert conv1d.attrs.padding[0].dtype == "int64" + assert conv1d.attrs.padding[1].dtype == "int64" + assert conv1d.attrs.dilation[0].dtype == "int64" + + +def test_conv1d_wrong_strides_padding_dilation_length(): + x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3), "float32")) + with pytest.raises(TVMError): + relax.op.nn.conv1d(x, w, strides=(1, 2)) + with pytest.raises(TVMError): + relax.op.nn.conv1d(x, w, padding=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.conv1d(x, w, dilation=(1, 2)) + + +def test_conv1d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x, w, data_layout="OIW")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x, w, kernel_layout="NWC")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x, w, out_layout="OWI")) + + +def test_conv1d_dtype_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + w = relax.Var("w", R.Tensor((4, 3, 3), "int8")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x, w)) + + +def test_conv1d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 28, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=2)) + w0 = relax.Var("w", R.Tensor((4, 3, 3), "float32")) + w1 = relax.Var("w", R.Tensor((4, 3, 6, 3), "float32")) + w2 = relax.Var("w", R.Tensor("float32", ndim=5)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x0, w1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x0, w1, data_layout="NCW16c")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x0, w2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x1, w0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x2, w0)) + + +def test_conv1d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28))) + w0 = relax.Var("w", R.Tensor((4, 3, 3), "float32")) + w1 = relax.Var("w", relax.FuncStructInfo([], R.Tensor((4, 3, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x0, w1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.conv1d(x1, w0)) + + def test_conv2d_infer_struct_info(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) diff --git a/tests/python/relax/test_tvmscript_parser_op_nn.py b/tests/python/relax/test_tvmscript_parser_op_nn.py index cfb454a578f7..5e569cc4f44c 100644 --- a/tests/python/relax/test_tvmscript_parser_op_nn.py +++ b/tests/python/relax/test_tvmscript_parser_op_nn.py @@ -35,6 +35,24 @@ def _check( tvm.ir.assert_structural_equal(parsed, expect) +def test_conv1d(): + @R.function + def foo( + x: 
R.Tensor((2, 3, 228), "float32"), w: R.Tensor((16, 3, 5), "float32") + ) -> R.Tensor((2, 16, 224), "float16"): + gv: R.Tensor((2, 16, 224), "float16") = R.nn.conv1d(x, w, out_dtype="float16") + return gv + + x = relax.Var("x", R.Tensor([2, 3, 228], "float32")) + w = relax.Var("w", R.Tensor([16, 3, 5], "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, w]): + gv = bb.emit(relax.op.nn.conv1d(x, w, out_dtype="float16")) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + def test_conv2d(): @R.function def foo( From 928b3968184aa12856a0d1ada98527aacb7d9ed9 Mon Sep 17 00:00:00 2001 From: Lesheng Jin Date: Thu, 23 Mar 2023 13:20:50 -0700 Subject: [PATCH 2/5] legalization for conv1d --- python/tvm/relax/transform/legalize_ops/nn.py | 40 ++++ .../relax/test_transform_legalize_ops_nn.py | 177 ++++++++++++++++++ 2 files changed, 217 insertions(+) diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index bfc054453607..889e6e09417b 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -24,6 +24,46 @@ from .common import register_legalize, _call_topi_without_attr +@register_legalize("relax.nn.conv1d") +def _nn_conv1d(bb: BlockBuilder, call: Call) -> Expr: + if call.attrs.out_layout != call.attrs.data_layout: + logging.info( + "TOPI conv1d does not support different input-output " + "layouts, and thus cannot be legalized by TOPI" + ) + return call + if len(call.attrs.data_layout) != 3 or len(call.attrs.kernel_layout) != 3: + logging.info( + "Conv1D where data layout or kernel layout have channel chunk " + "cannot be legalized by TOPI at this moment." + ) + return call + if call.attrs.groups != 1: + data_layout = tir.layout(call.attrs.data_layout) + kernel_layout = tir.layout(call.attrs.kernel_layout) + ic = call.args[0].struct_info.shape.values[data_layout.index_of("C")] + oc = call.args[1].struct_info.shape.values[kernel_layout.index_of("O")] + if not isinstance(ic, tir.IntImm) or not isinstance(oc, tir.IntImm): + logging.info( + "Conv1D where number of groups is more than one and input or output " + "channel size is symbolic cannot be legalized by TOPI at this moment." 
+ ) + return call + + return bb.call_te( + topi.nn.conv1d, + data=call.args[0], + kernel=call.args[1], + strides=call.attrs.strides, + padding=call.attrs.padding, + dilation=call.attrs.dilation, + data_layout=call.attrs.data_layout, + kernel_layout=call.attrs.kernel_layout, + out_dtype=call.attrs.out_dtype if call.attrs.out_dtype != "" else None, + primfunc_name_hint="conv1d", + ) + + @register_legalize("relax.nn.conv2d") def _nn_conv2d(bb: BlockBuilder, call: Call) -> Expr: if call.attrs.out_layout != call.attrs.data_layout: diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index a1fe266d68ea..242b78926cfd 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -25,6 +25,183 @@ ##################### Neural network ##################### +def test_conv1d(): + # fmt: off + @tvm.script.ir_module + class Conv1d: + @R.function + def main(x: R.Tensor((2, 128, 28), "float32"), w: R.Tensor((64, 16, 3), "float32")) -> R.Tensor((2, 64, 13), "float32"): + gv: R.Tensor((2, 4, 13), "float32") = R.nn.conv1d(x, w, strides=(2,), padding=(1,), dilation=(2,), groups=8) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 128, 28), dtype="float32"), w: R.Tensor((64, 16, 3), dtype="float32")) -> R.Tensor((2, 64, 13), dtype="float32"): + gv = R.call_tir(Expected.conv1d, (x, w), out_sinfo=R.Tensor((2, 64, 13), dtype="float32")) + return gv + + @T.prim_func + def conv1d(rxplaceholder: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(64), T.int64(16), T.int64(3)), "float32"), conv1d_ncw: T.Buffer((T.int64(2), T.int64(64), T.int64(13)), "float32")): + T.func_attr({"tir.noalias": True}) + pad_temp = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(30))) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(128), T.int64(30)): + with T.block("pad_temp"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1)]) + T.writes(pad_temp[v_i0, v_i1, v_i2]) + pad_temp[v_i0, v_i1, v_i2] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(29), rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1)], T.float32(0)) + for nn, ff, yy, rc, ry in T.grid(T.int64(2), T.int64(64), T.int64(13), T.int64(128), T.int64(3)): + with T.block("conv1d_ncw"): + v_nn, v_ff, v_yy, v_rc, v_ry = T.axis.remap("SSSRR", [nn, ff, yy, rc, ry]) + T.reads(pad_temp[v_nn, v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)], rxplaceholder_1[v_ff, v_rc, v_ry]) + T.writes(conv1d_ncw[v_nn, v_ff, v_yy]) + with T.init(): + conv1d_ncw[v_nn, v_ff, v_yy] = T.float32(0) + conv1d_ncw[v_nn, v_ff, v_yy] = conv1d_ncw[v_nn, v_ff, v_yy] + pad_temp[v_nn, v_rc, v_yy * T.int64(2) + v_ry * T.int64(2)] * rxplaceholder_1[v_ff, v_rc, v_ry] + # fmt: on + + mod = LegalizeOps()(Conv1d) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_conv1d_with_out_dtype(): + # fmt: off + @tvm.script.ir_module + class Conv1d: + @R.function + def main(x: R.Tensor((2, 3, 28), "float32"), w: R.Tensor((4, 3, 3), "float32")) -> R.Tensor((2, 4, 26), "float16"): + gv: R.Tensor((2, 4, 26), "float16") = R.nn.conv1d(x, w, out_dtype="float16") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 28), dtype="float32"), w: R.Tensor((4, 3, 3), dtype="float32")) -> R.Tensor((2, 4, 26), dtype="float16"): + gv = R.call_tir(Expected.conv1d, (x, w), out_sinfo=R.Tensor((2, 4, 26), 
dtype="float16")) + return gv + + @T.prim_func + def conv1d(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(3)), "float32"), conv1d_ncw: T.Buffer((T.int64(2), T.int64(4), T.int64(26)), "float16")): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + pad_temp = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28))) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(28)): + with T.block("pad_temp"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2]) + T.writes(pad_temp[v_i0, v_i1, v_i2]) + pad_temp[v_i0, v_i1, v_i2] = rxplaceholder[v_i0, v_i1, v_i2] + for nn, ff, yy, rc, ry in T.grid(T.int64(2), T.int64(4), T.int64(26), T.int64(3), T.int64(3)): + with T.block("conv1d_ncw"): + v_nn, v_ff, v_yy, v_rc, v_ry = T.axis.remap("SSSRR", [nn, ff, yy, rc, ry]) + T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry], rxplaceholder_1[v_ff, v_rc, v_ry]) + T.writes(conv1d_ncw[v_nn, v_ff, v_yy]) + with T.init(): + conv1d_ncw[v_nn, v_ff, v_yy] = T.float16(0) + conv1d_ncw[v_nn, v_ff, v_yy] = conv1d_ncw[v_nn, v_ff, v_yy] + T.Cast("float16", pad_temp[v_nn, v_rc, v_yy + v_ry]) * T.Cast("float16", rxplaceholder_1[v_ff, v_rc, v_ry]) + # fmt: on + + mod = LegalizeOps()(Conv1d) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_conv1d_nwc(): + # fmt: off + @tvm.script.ir_module + class Conv1d: + @R.function + def main(x: R.Tensor((2, 28, 128), "float32"), w: R.Tensor((64, 128, 3), "float32")) -> R.Tensor((2, 26, 64), "float32"): + gv: R.Tensor((2, 26, 64), "float32") = R.nn.conv1d(x, w, data_layout="NWC") + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 28, 128), dtype="float32"), w: R.Tensor((64, 128, 3), dtype="float32")) -> R.Tensor((2, 26, 64), dtype="float32"): + gv = R.call_tir(Expected.conv1d, (x, w), out_sinfo=R.Tensor((2, 26, 64), dtype="float32")) + return gv + + @T.prim_func + def conv1d(rxplaceholder: T.Buffer((T.int64(2), T.int64(28), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(64), T.int64(128), T.int64(3)), "float32"), conv1d_nwc: T.Buffer((T.int64(2), T.int64(26), T.int64(64)), "float32")): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + pad_temp = T.alloc_buffer((T.int64(2), T.int64(28), T.int64(128))) + for i0, i1, i2 in T.grid(T.int64(2), T.int64(28), T.int64(128)): + with T.block("pad_temp"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2]) + T.writes(pad_temp[v_i0, v_i1, v_i2]) + pad_temp[v_i0, v_i1, v_i2] = rxplaceholder[v_i0, v_i1, v_i2] + for nn, yy, ff, ry, rc in T.grid(T.int64(2), T.int64(26), T.int64(64), T.int64(3), T.int64(128)): + with T.block("conv1d_nwc"): + v_nn, v_yy, v_ff, v_ry, v_rc = T.axis.remap("SSSRR", [nn, yy, ff, ry, rc]) + T.reads(pad_temp[v_nn, v_yy + v_ry, v_rc], rxplaceholder_1[v_ff, v_rc, v_ry]) + T.writes(conv1d_nwc[v_nn, v_yy, v_ff]) + with T.init(): + conv1d_nwc[v_nn, v_yy, v_ff] = T.float32(0) + conv1d_nwc[v_nn, v_yy, v_ff] = conv1d_nwc[v_nn, v_yy, v_ff] + pad_temp[v_nn, v_yy + v_ry, v_rc] * rxplaceholder_1[v_ff, v_rc, v_ry] + # fmt: on + + mod = LegalizeOps()(Conv1d) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_conv1d_symbolic(): + # fmt: off + @tvm.script.ir_module + class Conv1d: + @R.function + def main(x: R.Tensor(("n", "c", "w"), "float32"), kernel: R.Tensor(("f", "c", "kw"), "float32")) -> R.Tensor(("n", "f", "w - kw + 1"), "float32"): + n = T.int64() + w = 
T.int64() + f = T.int64() + kw = T.int64() + gv: R.Tensor((n, f, w - kw + 1), "float32") = R.nn.conv1d(x, kernel) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("n", "c", "w"), dtype="float32"), kernel: R.Tensor(("f", "c", "kw"), dtype="float32")) -> R.Tensor(("n", "f", "w - kw + 1"), dtype="float32"): + n = T.int64() + f = T.int64() + w = T.int64() + kw = T.int64() + c = T.int64() + gv = R.call_tir(Expected.conv1d, (x, kernel), out_sinfo=R.Tensor((n, f, w - kw + 1), dtype="float32")) + return gv + + @T.prim_func + def conv1d(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_conv1d_ncw: T.handle): + T.func_attr({"tir.noalias": True}) + n, c, w = T.int64(), T.int64(), T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (n, c, w)) + f, kw = T.int64(), T.int64() + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (f, c, kw)) + conv1d_ncw = T.match_buffer(var_conv1d_ncw, (n, f, w - kw + T.int64(1))) + # with T.block("root"): + pad_temp = T.alloc_buffer((n, c, w)) + for i0, i1, i2 in T.grid(n, c, w): + with T.block("pad_temp"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[v_i0, v_i1, v_i2]) + T.writes(pad_temp[v_i0, v_i1, v_i2]) + pad_temp[v_i0, v_i1, v_i2] = rxplaceholder[v_i0, v_i1, v_i2] + for nn, ff, yy, rc, ry in T.grid(n, f, w + T.int64(1) - kw, c, kw): + with T.block("conv1d_ncw"): + v_nn, v_ff, v_yy, v_rc, v_ry = T.axis.remap("SSSRR", [nn, ff, yy, rc, ry]) + T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry], rxplaceholder_1[v_ff, v_rc, v_ry]) + T.writes(conv1d_ncw[v_nn, v_ff, v_yy]) + with T.init(): + conv1d_ncw[v_nn, v_ff, v_yy] = T.float32(0) + conv1d_ncw[v_nn, v_ff, v_yy] = conv1d_ncw[v_nn, v_ff, v_yy] + pad_temp[v_nn, v_rc, v_yy + v_ry] * rxplaceholder_1[v_ff, v_rc, v_ry] + # fmt: on + + mod = LegalizeOps()(Conv1d) + tvm.ir.assert_structural_equal(mod, Expected) + + def test_conv2d(): # fmt: off @tvm.script.ir_module From 70aa5056ca5f5560776ca85d5bc8fb60b1b44d1f Mon Sep 17 00:00:00 2001 From: Lesheng Jin Date: Thu, 23 Mar 2023 13:44:39 -0700 Subject: [PATCH 3/5] fx translator for conv1d --- .../tvm/relax/frontend/torch/fx_translator.py | 29 ++++++ tests/python/relax/test_frontend_from_fx.py | 88 ++++++++++++++++++- 2 files changed, 116 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index ef6793cc6712..c65e94d6916e 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -637,6 +637,34 @@ def _linear(self, node: fx.node.Node) -> relax.Var: bias = None if module.bias is None else self.params[module.bias] return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) + def _conv1d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + + conv1d = self.block_builder.emit( + relax.op.nn.conv1d( + x, + weight, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + data_layout="NCW", + kernel_layout="OIW", + out_dtype="float32", + ) + ) + + if module.bias is None: + return conv1d + + bias = self.params[module.bias] + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1)) + + return self.block_builder.emit(relax.op.add(conv1d, bias)) + def _conv2d(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] 
@@ -1001,6 +1029,7 @@ def create_convert_map(self): self.convert_map: Dict[Union[nn.Module, str], Callable[[fx.node.Node], relax.Var]] = { # call_module nn.Linear: self._linear, + nn.Conv1d: self._conv1d, nn.Conv2d: self._conv2d, nn.MaxPool2d: self._max_pool2d, nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d(is_module=True), diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index d201cb111c66..9e07ff7b59f7 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -37,7 +37,93 @@ def verify_model(torch_model, input_info, binding, expected): @tvm.testing.requires_gpu -def test_conv(): +def test_conv1d(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + class Conv1D1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d(3, 6, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7), dtype="float32"), + w2: R.Tensor((6,), dtype="float32"), + ) -> R.Tensor((1, 6, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4), dtype="float32") = R.nn.conv1d( + input_1, + w1, + strides=[1], + padding=[0, 0], + dilation=[1], + data_layout="NCW", + kernel_layout="OIW", + out_layout="NCW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1]) + lv3: R.Tensor((1, 6, 4), dtype="float32") = R.add(lv1, lv2) + gv: R.Tensor((1, 6, 4), dtype="float32") = lv3 + R.output(gv) + return gv + + class Conv1D2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d(3, 6, 7, bias=False) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7), dtype="float32"), + ) -> R.Tensor((1, 6, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4), dtype="float32") = R.nn.conv1d( + input_1, + w1, + strides=[1], + padding=[0, 0], + dilation=[1], + data_layout="NCW", + kernel_layout="OIW", + out_layout="NCW", + out_dtype="float32", + ) + gv: R.Tensor((1, 6, 4), dtype="float32") = lv1 + R.output(gv) + return gv + + input_info = [([1, 3, 10], "float32")] + + model = Conv1D1() + binding = {"w1": model.conv.weight.numpy(), "w2": model.conv.bias.numpy()} + verify_model(model, input_info, binding, expected1) + + model = Conv1D2() + binding = {"w1": model.conv.weight.numpy()} + verify_model(model, input_info, binding, expected2) + + +@tvm.testing.requires_gpu +def test_conv2d(): import torch from torch.nn import Module From 33864f79c38a16a4b199ab03b71a06bf1528ed7e Mon Sep 17 00:00:00 2001 From: Lesheng Jin Date: Fri, 24 Mar 2023 00:22:22 -0700 Subject: [PATCH 4/5] remove white spce --- tests/python/relax/test_transform_legalize_ops_nn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index 242b78926cfd..e944b8d76ebe 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -40,7 +40,7 @@ class Expected: def main(x: R.Tensor((2, 128, 28), dtype="float32"), w: R.Tensor((64, 16, 3), dtype="float32")) -> R.Tensor((2, 64, 13), dtype="float32"): gv = 
R.call_tir(Expected.conv1d, (x, w), out_sinfo=R.Tensor((2, 64, 13), dtype="float32")) return gv - + @T.prim_func def conv1d(rxplaceholder: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(64), T.int64(16), T.int64(3)), "float32"), conv1d_ncw: T.Buffer((T.int64(2), T.int64(64), T.int64(13)), "float32")): T.func_attr({"tir.noalias": True}) @@ -80,7 +80,7 @@ class Expected: def main(x: R.Tensor((2, 3, 28), dtype="float32"), w: R.Tensor((4, 3, 3), dtype="float32")) -> R.Tensor((2, 4, 26), dtype="float16"): gv = R.call_tir(Expected.conv1d, (x, w), out_sinfo=R.Tensor((2, 4, 26), dtype="float16")) return gv - + @T.prim_func def conv1d(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(3)), "float32"), conv1d_ncw: T.Buffer((T.int64(2), T.int64(4), T.int64(26)), "float16")): T.func_attr({"tir.noalias": True}) @@ -121,7 +121,7 @@ class Expected: def main(x: R.Tensor((2, 28, 128), dtype="float32"), w: R.Tensor((64, 128, 3), dtype="float32")) -> R.Tensor((2, 26, 64), dtype="float32"): gv = R.call_tir(Expected.conv1d, (x, w), out_sinfo=R.Tensor((2, 26, 64), dtype="float32")) return gv - + @T.prim_func def conv1d(rxplaceholder: T.Buffer((T.int64(2), T.int64(28), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(64), T.int64(128), T.int64(3)), "float32"), conv1d_nwc: T.Buffer((T.int64(2), T.int64(26), T.int64(64)), "float32")): T.func_attr({"tir.noalias": True}) @@ -171,7 +171,7 @@ def main(x: R.Tensor(("n", "c", "w"), dtype="float32"), kernel: R.Tensor(("f", " c = T.int64() gv = R.call_tir(Expected.conv1d, (x, kernel), out_sinfo=R.Tensor((n, f, w - kw + 1), dtype="float32")) return gv - + @T.prim_func def conv1d(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_conv1d_ncw: T.handle): T.func_attr({"tir.noalias": True}) From 20d79d858ac39c0ac475606b5b3d1b40db01f8e4 Mon Sep 17 00:00:00 2001 From: Lesheng Jin Date: Mon, 27 Mar 2023 11:22:18 -0700 Subject: [PATCH 5/5] fix test --- .../relax/test_tvmscript_parser_op_nn.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/python/relax/test_tvmscript_parser_op_nn.py b/tests/python/relax/test_tvmscript_parser_op_nn.py index 5e569cc4f44c..a822fae71922 100644 --- a/tests/python/relax/test_tvmscript_parser_op_nn.py +++ b/tests/python/relax/test_tvmscript_parser_op_nn.py @@ -38,13 +38,13 @@ def _check( def test_conv1d(): @R.function def foo( - x: R.Tensor((2, 3, 228), "float32"), w: R.Tensor((16, 3, 5), "float32") + x: R.Tensor((2, 3, 228), "float16"), w: R.Tensor((16, 3, 5), "float16") ) -> R.Tensor((2, 16, 224), "float16"): gv: R.Tensor((2, 16, 224), "float16") = R.nn.conv1d(x, w, out_dtype="float16") return gv - x = relax.Var("x", R.Tensor([2, 3, 228], "float32")) - w = relax.Var("w", R.Tensor([16, 3, 5], "float32")) + x = relax.Var("x", R.Tensor([2, 3, 228], "float16")) + w = relax.Var("w", R.Tensor([16, 3, 5], "float16")) bb = relax.BlockBuilder() with bb.function("foo", [x, w]): gv = bb.emit(relax.op.nn.conv1d(x, w, out_dtype="float16")) @@ -56,13 +56,13 @@ def foo( def test_conv2d(): @R.function def foo( - x: R.Tensor((2, 3, 228, 228), "float32"), w: R.Tensor((16, 3, 5, 5), "float32") + x: R.Tensor((2, 3, 228, 228), "float16"), w: R.Tensor((16, 3, 5, 5), "float16") ) -> R.Tensor((2, 16, 224, 224), "float16"): gv: R.Tensor((2, 16, 224, 224), "float16") = R.nn.conv2d(x, w, out_dtype="float16") return gv - x = relax.Var("x", R.Tensor([2, 3, 228, 228], "float32")) - w 
= relax.Var("w", R.Tensor([16, 3, 5, 5], "float32")) + x = relax.Var("x", R.Tensor([2, 3, 228, 228], "float16")) + w = relax.Var("w", R.Tensor([16, 3, 5, 5], "float16")) bb = relax.BlockBuilder() with bb.function("foo", [x, w]): gv = bb.emit(relax.op.nn.conv2d(x, w, out_dtype="float16")) @@ -74,15 +74,15 @@ def foo( def test_conv2d_transpose(): @R.function def foo( - x: R.Tensor((2, 3, 228, 228), "float32"), w: R.Tensor((3, 16, 5, 5), "float32") + x: R.Tensor((2, 3, 228, 228), "float16"), w: R.Tensor((3, 16, 5, 5), "float16") ) -> R.Tensor((2, 16, 232, 232), "float16"): gv: R.Tensor((2, 16, 232, 232), "float16") = R.nn.conv2d_transpose( x, w, out_dtype="float16" ) return gv - x = relax.Var("x", R.Tensor([2, 3, 228, 228], "float32")) - w = relax.Var("w", R.Tensor([3, 16, 5, 5], "float32")) + x = relax.Var("x", R.Tensor([2, 3, 228, 228], "float16")) + w = relax.Var("w", R.Tensor([3, 16, 5, 5], "float16")) bb = relax.BlockBuilder() with bb.function("foo", [x, w]): gv = bb.emit(relax.op.nn.conv2d_transpose(x, w, out_dtype="float16"))