diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 694a51070683..61b1622a6082 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -174,6 +174,27 @@ struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> { } }; // struct LayerNormAttrs +/*! \brief Attributes used in group_norm operator */ +struct GroupNormAttrs : public tvm::AttrsNode<GroupNormAttrs> { + int num_groups; + int channel_axis; + Array<Integer> axes; + double epsilon; + bool center; + bool scale; + + TVM_DECLARE_ATTRS(GroupNormAttrs, "relax.attrs.GroupNormAttrs") { + TVM_ATTR_FIELD(num_groups).describe("The number of groups to separate the channels into."); + TVM_ATTR_FIELD(channel_axis).describe("The axis that represents the channel."); + TVM_ATTR_FIELD(axes).describe( + "The axes along which the normalization is applied (excluding the channel axis)."); + TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero"); + TVM_ATTR_FIELD(center).describe( + "Indicating if the beta offset will be added to the normalized tensor."); + TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be multiplied."); + } +}; // struct GroupNormAttrs + /*! \brief Attributes used in dropout operator */ struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> { double rate; diff --git a/include/tvm/topi/nn/group_norm.h b/include/tvm/topi/nn/group_norm.h new file mode 100644 index 000000000000..43760bab1fd0 --- /dev/null +++ b/include/tvm/topi/nn/group_norm.h @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*!
+ * \brief group normalization op constructions + * \file nn/group_norm.h + */ +#ifndef TVM_TOPI_NN_GROUP_NORM_H_ +#define TVM_TOPI_NN_GROUP_NORM_H_ + +#include +#include + +#include +#include +#include + +namespace tvm { +namespace topi { +namespace nn { + +using namespace tvm::te; + +inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta, + int num_groups, int channel_axis, const Array<Integer>& axes, + double epsilon, std::string name = "T_group_norm", + std::string tag = kInjective) { + // reshape data C -> G, C/G + int ndim = data->shape.size(); + channel_axis = GetRealAxis(ndim, {channel_axis})[0]; + + auto shape = data->shape; + auto group_size = floordiv(shape[channel_axis], num_groups); + auto new_shape = Array<PrimExpr>(); + for (int i = 0; i < ndim; ++i) { + if (i == channel_axis) { + new_shape.push_back(num_groups); + new_shape.push_back(group_size); + } else { + new_shape.push_back(shape[i]); + } + } + auto data_reshaped = reshape(data, new_shape); + // reshape gamma and beta, C -> G, C/G + Tensor gamma_reshaped; + if (gamma.defined()) { + gamma_reshaped = reshape(gamma, {num_groups, group_size}); + } + Tensor beta_reshaped; + if (beta.defined()) { + beta_reshaped = reshape(beta, {num_groups, group_size}); + } + + // get the new axes to normalize after reshape + std::vector<int> new_axes{channel_axis + 1}; + for (auto axis : axes) { + int new_axis = GetRealAxis(ndim, {axis})[0]; + if (new_axis < channel_axis) { + new_axes.push_back(new_axis); + } else if (new_axis > channel_axis) { + new_axes.push_back(new_axis + 1); + } else { + ICHECK(false) << "axes can not contain channel axis"; + } + } + std::sort(new_axes.begin(), new_axes.end()); + + // sum x and x^2 + ndim = data_reshaped->shape.size(); + auto reduce_axes = MakeReduceAxes(new_axes, data_reshaped); + auto target_shape = + MakeReduceTargetShape(new_axes, data_reshaped, /*keepdims=*/false, /*atleast1d=*/true); + auto func = MakeTupleSumReducer(); + + auto compute = [ndim, &new_axes, &reduce_axes, &func, &data_reshaped](const Array<Var>& indices) { + Array<PrimExpr> eval_range; + int arg_counter = 0; + int red_counter = 0; + + for (int i = 0; i < ndim; ++i) { + if (std::find(new_axes.begin(), new_axes.end(), i) != new_axes.end()) { + // new_axes contains i + eval_range.push_back(reduce_axes[red_counter]); + red_counter++; + } else { + eval_range.push_back(indices[arg_counter]); + arg_counter++; + } + } + auto square = [](const PrimExpr& x) { return x * x; }; + return func({data_reshaped(eval_range), square(data_reshaped(eval_range))}, reduce_axes, + nullptr); + }; + + auto temp_x_x2 = + tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduce); + + auto temp_x = temp_x_x2[0]; + auto temp_x2 = temp_x_x2[1]; + auto reduce_extent = make_const(data->dtype, 1); + for (auto axis : new_axes) { + reduce_extent *= data_reshaped->shape[axis]; + } + auto group_norm_func = [&](const Array<Var>& indices) { + Array<PrimExpr> reduce_indices, non_reduce_indices, gamma_indices; + for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) { + if (std::find(new_axes.begin(), new_axes.end(), i) != new_axes.end()) { + reduce_indices.push_back(indices[i]); + } else { + non_reduce_indices.push_back(indices[i]); + } + } + gamma_indices = {indices[channel_axis], indices[channel_axis + 1]}; + auto mean = temp_x(non_reduce_indices) / reduce_extent; + auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean; + auto group_norm = + (data_reshaped(indices) - mean) * tvm::rsqrt(var + make_const(data->dtype, epsilon)); + if
(gamma.defined()) { + group_norm = topi::multiply(group_norm, gamma_reshaped(gamma_indices)); + } + if (beta.defined()) { + group_norm = topi::add(group_norm, beta_reshaped(gamma_indices)); + } + return group_norm; + }; + auto group_norm_out = tvm::te::compute(data_reshaped->shape, group_norm_func, name, tag); + auto group_norm_out_reshaped = reshape(group_norm_out, shape); + return group_norm_out_reshaped; +} + +} // namespace nn +} // namespace topi +} // namespace tvm + +#endif // TVM_TOPI_NN_GROUP_NORM_H_ diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index e80f73096c59..24fcf0caca64 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -465,44 +465,30 @@ def _layer_norm(self, node: fx.node.Node) -> relax.Var: ) def _group_norm(self, node: fx.node.Node) -> relax.Var: - # torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, - # affine=True, device=None, dtype=None) + import torch # type: ignore + x = self.env[node.args[0]] module = self.named_modules[node.target] - num_groups = module.num_groups - num_channels = module.num_channels - eps = module.eps - affine = module.affine - shape = self.shape_of(x) - assert len(shape) == 4 - N, C, H, W = shape[0], shape[1], shape[2], shape[3] - assert C == num_channels - assert C % num_groups == 0 - grouped_x = self.block_builder.emit( - relax.op.reshape(x, [N, num_groups, C // num_groups, H, W]) - ) - mean_x = self.block_builder.emit(relax.op.mean(grouped_x, [2, 3, 4], keepdims=True)) - sub_x = self.block_builder.emit(relax.op.subtract(grouped_x, mean_x)) - square_x = self.block_builder.emit(relax.op.multiply(sub_x, sub_x)) - sum_square_x = self.block_builder.emit(relax.op.sum(square_x, [2, 3, 4], keepdims=True)) - var_x = self._call_binary_op(relax.op.divide, sum_square_x, (C // num_groups * H * W).value) - var_x_eps = self._call_binary_op(relax.op.add, var_x, eps) - std_x = self.block_builder.emit(relax.op.sqrt(var_x_eps)) - norm_x = self.block_builder.emit(relax.op.divide(sub_x, std_x)) - - if affine: - weight = self.params[module.weight] - bias = self.params[module.bias] - weight_reshape = self.block_builder.emit( - relax.op.reshape(weight, (1, num_groups, C // num_groups, 1, 1)) - ) - bias_reshape = self.block_builder.emit( - relax.op.reshape(bias, (1, num_groups, C // num_groups, 1, 1)) + if module.affine: + gamma = self.params[module.weight] + beta = self.params[module.bias] + else: + gamma = relax.const(torch.ones_like(module.num_channels), x.checked_type) + beta = relax.const(torch.zeros_like(module.num_channels), x.checked_type) + + dim = len(self.shape_of(x)) + return self.block_builder.emit( + relax.op.nn.group_norm( + x, + gamma, + beta, + num_groups=module.num_groups, + channel_axis=1, + axes=list(range(2, dim)), + epsilon=module.eps, ) - norm_x = self.block_builder.emit(relax.op.multiply(norm_x, weight_reshape)) - norm_x = self.block_builder.emit(relax.op.add(norm_x, bias_reshape)) - return self.block_builder.emit(relax.op.reshape(norm_x, (N, C, H, W))) + ) def _embedding(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 2fef37249703..bbb1268f1c96 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -527,6 +527,64 @@ def layer_norm( return _ffi_api.layer_norm(data, gamma, beta, axes, epsilon, center, scale) # type: ignore +def group_norm( + data: Expr, + gamma: Expr, + beta: Expr, 
+ num_groups: int, + channel_axis: int, + axes: Union[int, List[int]], + epsilon: float = 1e-5, + center: bool = True, + scale: bool = True, +) -> Expr: + r""" + Group normalization (Yuxin Wu et al., 2018). + Applies group normalization to the n-dimensional input array. + This operator takes an n-dimensional input array. It first separates the input array + into groups along the channel axis, then applies layer normalization to each group. + + Parameters + ---------- + data : relax.Expr + Input to which group_norm will be applied. + + gamma : relax.Expr + The gamma scale factor. + + beta : relax.Expr + The beta offset factor. + + num_groups : int + Number of groups to separate the channels into. + + channel_axis : int + The index of the channel axis in the input data. + + axes : Union[int, List[int]] + The axes along which the normalization is applied (excluding the channel axis). + + epsilon : float + Small float added to variance to avoid dividing by zero. + + center : bool + Indicating if the beta offset will be added to the normalized tensor. + + scale : bool + Indicating if the gamma scale will be multiplied. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axes, int): + axes = [axes] + return _ffi_api.group_norm( # type: ignore + data, gamma, beta, num_groups, channel_axis, axes, epsilon, center, scale + ) + + def dropout(data: Expr, rate: float = 0.5) -> Expr: """Applies the dropout operation to the input tensor. diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 70bb2513dda3..a61e0cd09ee1 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -196,6 +196,20 @@ def _nn_layer_norm(bb: BlockBuilder, call: Call) -> Expr: ) +@register_legalize("relax.nn.group_norm") +def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + topi.nn.group_norm, + call.args[0], + call.args[1], + call.args[2], + call.attrs.num_groups, + call.attrs.channel_axis, + call.attrs.axes, + call.attrs.epsilon, + ) + + @register_legalize("relax.nn.dropout") def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr: logging.info("Dropout is handled by frontend translator at this moment and is not legalized.") diff --git a/python/tvm/topi/nn/__init__.py b/python/tvm/topi/nn/__init__.py index 8f081242fa10..80a21e65313e 100644 --- a/python/tvm/topi/nn/__init__.py +++ b/python/tvm/topi/nn/__init__.py @@ -39,6 +39,7 @@ from .qnn import * from .upsampling import * from .layer_norm import layer_norm +from .group_norm import group_norm from .local_response_norm import * from .bitserial_conv2d import * from .bitserial_dense import * diff --git a/python/tvm/topi/nn/group_norm.py b/python/tvm/topi/nn/group_norm.py new file mode 100644 index 000000000000..c6358b8bc6ff --- /dev/null +++ b/python/tvm/topi/nn/group_norm.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Group normalization operator.""" +from .. import cpp + + +def group_norm(data, gamma, beta, num_groups, channel_axis, axes, epsilon=1e-5): + """Group normalization operator. + + Parameters + ---------- + data : tvm.te.Tensor + N-D with shape (d_0, d_1, ..., d_{N-1}) + + gamma: tvm.te.Tensor + 1-D with shape (r_0) where r_0 == d_{channel_axis} + + beta: tvm.te.Tensor + Optional, 1-D with shape (r_0) where r_0 == d_{channel_axis} + + num_groups : int + The number of groups + + channel_axis : int + The channel axis + + axes : list of int + The axes over which the normalization is applied, excluding the channel axis + + epsilon : float + The epsilon value to avoid division by zero. + + Returns + ------- + result : tvm.te.Tensor + N-D with shape (d_0, d_1, ..., d_{N-1}) + """ + return cpp.nn.group_norm(data, gamma, beta, num_groups, channel_axis, axes, epsilon) diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index 2922c30b505c..ef480905833c 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -44,6 +44,7 @@ from .roi_align_python import roi_align_nchw_python, roi_align_nhwc_python from .roi_pool_python import roi_pool_nchw_python from .layer_norm_python import layer_norm_python +from .group_norm_python import group_norm_python from .lrn_python import lrn_python from .l2_normalize_python import l2_normalize_python from .gather_python import gather_python diff --git a/python/tvm/topi/testing/group_norm_python.py b/python/tvm/topi/testing/group_norm_python.py new file mode 100644 index 000000000000..d1c0d4a6abcc --- /dev/null +++ b/python/tvm/topi/testing/group_norm_python.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals +"""Group normalization in python""" +import numpy as np + + +def group_norm_python(data, gamma, beta, num_groups, channel_axis, axes, epsilon=1e-5): + """Group normalization operator.
+ + Parameters + ---------- + data : tvm.te.Tensor + N-D with shape (d_0, d_1, ..., d_{N-1}) + + gamma: tvm.te.Tensor + 1-D with shape (r_0) where r_0 == d_{channel_axis} + + beta: tvm.te.Tensor + Optional, 1-D with shape (r_0) where r_0 == d_{channel_axis} + + num_groups : int + The number of groups + + channel_axis : int + The channel axis + + axes : list of int + The axes over which the normalization is applied, excluding the channel axis + + epsilon : float + The epsilon value to avoid division by zero. + + Returns + ------- + result : tvm.te.Tensor + N-D with shape (d_0, d_1, ..., d_{N-1}) + """ + old_shape = data.shape + new_shape = list(old_shape) + new_shape[channel_axis] = data.shape[channel_axis] // num_groups + new_shape.insert(channel_axis, num_groups) + data = np.reshape(data, new_shape) + new_axes = [channel_axis + 1] + for axis in axes: + if axis < channel_axis: + new_axes.append(axis) + else: + new_axes.append(axis + 1) + mean = np.mean(data, axis=tuple(new_axes), keepdims=True) + var = np.var(data, axis=tuple(new_axes), keepdims=True) + data = (data - mean) / np.sqrt(var + epsilon) + data = np.reshape(data, old_shape) + + gamma_broadcast_shape = [1 for _ in range(len(old_shape))] + gamma_broadcast_shape[channel_axis] = gamma.shape[0] + gamma = np.reshape(gamma, gamma_broadcast_shape) + + if beta is not None: + beta_broadcast_shape = [1 for _ in range(len(old_shape))] + beta_broadcast_shape[channel_axis] = beta.shape[0] + beta = np.reshape(beta, beta_broadcast_shape) + + data *= gamma + if beta is not None: + data += beta + + return data diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index e63b3306f25d..430d2268cec3 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -233,6 +233,89 @@ TVM_REGISTER_OP("relax.nn.layer_norm") .add_argument("beta", "Tensor", "The beta offset factor.") .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoLayerNorm); +/* relax.nn.group_norm */ +TVM_REGISTER_NODE_TYPE(GroupNormAttrs); + +Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis, + Array<Integer> axes, double epsilon, bool center, bool scale) { + ObjectPtr<GroupNormAttrs> attrs = make_object<GroupNormAttrs>(); + attrs->num_groups = num_groups; + attrs->channel_axis = channel_axis; + attrs->axes = std::move(axes); + attrs->epsilon = epsilon; + attrs->center = center; + attrs->scale = scale; + + static const Op& op = Op::Get("relax.nn.group_norm"); + return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.group_norm").set_body_typed(group_norm); + +StructInfo InferStructInfoGroupNorm(const Call& call, const BlockBuilder& ctx) { + Op op = Downcast<Op>(call->op); + Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as<GroupNormAttrs>(); + + TensorStructInfo data_sinfo = input_sinfo[0]; + int channel_axis = -1; + if (!data_sinfo->IsUnknownNdim()) { + channel_axis = NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->channel_axis); + std::vector<int> axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axes); + // channel_axis must not be in axes.
+ if (std::find(axes.begin(), axes.end(), channel_axis) != axes.end()) { + ctx->ReportFatal(Diagnostic::Error(call) + << op + << " expects that channel_axis must not be in axes, but got channel_axis: " + << channel_axis << ", axes: " << attrs->axes); + } + } + if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " expects that data must be float, but got " << data_sinfo->dtype); + } + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>(); + if (data_shape != nullptr && channel_axis != -1 && + analyzer->CanProve(floormod(data_shape->values[channel_axis], attrs->num_groups) != 0)) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " expects that the size of channel_axis must be divisible by " + << attrs->num_groups << ", but got " << data_shape->values[channel_axis]); + } + for (int i = 1; i < static_cast<int>(op->arguments.size()); ++i) { + if (input_sinfo[i]->dtype != data_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " expects that all inputs must have the same dtype, but got " + << input_sinfo[i]->dtype << " and " << data_sinfo->dtype); + } else if (input_sinfo[i]->ndim != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " expects that all inputs must have ndim=1, but got " + << input_sinfo[i]->ndim); + } else if (channel_axis != -1) { + const auto* shape = input_sinfo[i]->shape.as<ShapeExprNode>(); + if (shape != nullptr && data_shape != nullptr) { + PrimExpr channel_size = data_shape->values[channel_axis]; + PrimExpr input_size = shape->values[0]; + if (analyzer->CanProve(channel_size != input_size)) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " expects that the size of input " << i + << " must be equal to the size of channel_axis, but got " << input_size + << " and " << channel_size); + } + } + } + } + return data_sinfo; +} + +TVM_REGISTER_OP("relax.nn.group_norm") + .set_attrs_type<GroupNormAttrs>() + .set_num_inputs(3) + .add_argument("data", "Tensor", "Input to which group_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGroupNorm); + /* relax.nn.dropout */ TVM_REGISTER_NODE_TYPE(DropoutAttrs); diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h index f13b930fc246..f578f89346f7 100644 --- a/src/relax/op/nn/nn.h +++ b/src/relax/op/nn/nn.h @@ -68,6 +68,10 @@ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_ Expr layer_norm(Expr data, Expr gamma, Expr beta, Array<Integer> axes, double epsilon, bool center, bool scale); +/*! \brief Compute group normalization. */ +Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis, + Array<Integer> axes, double epsilon, bool center, bool scale); + /*! * \brief Applies the dropout operation to the input tensor. * \param data The input data to the operator.
diff --git a/src/topi/nn.cc b/src/topi/nn.cc index 35dbf3a03e4f..3b2c11010ff1 100644 --- a/src/topi/nn.cc +++ b/src/topi/nn.cc @@ -29,6 +29,7 @@ #include #include #include +#include <tvm/topi/nn/group_norm.h> #include #include #include @@ -163,5 +164,11 @@ TVM_REGISTER_GLOBAL("topi.nn.layer_norm").set_body([](TVMArgs args, TVMRetValue* *rv = nn::layer_norm(args[0], args[1], args[2], args[3], static_cast<double>(args[4])); }); +/* Ops from nn/group_norm.h */ +TVM_REGISTER_GLOBAL("topi.nn.group_norm").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = nn::group_norm(args[0], args[1], args[2], static_cast<int>(args[3]), + static_cast<int>(args[4]), args[5], static_cast<double>(args[6])); +}); + } // namespace topi } // namespace tvm diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index ba3c930a456f..c21dbd2bd1f5 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -362,7 +362,7 @@ def f( y: R.Tensor(("m",), "float32"), r: R.Tensor(dtype="int64"), ) -> R.Object: - m = T.var("int64") + m = T.int64() z: R.Tensor((32, m), "float32") = R.multiply(x, y) w: R.Tensor = R.multiply(z, z) q: R.Tensor(ndim=2) = R.add(w, w) @@ -431,7 +431,7 @@ def test_call_tir(): # also from test_parser @R.function def foo(x: R.Tensor(("m", "n"), "float32")): - m, n = T.var("int64"), T.var("int64") + m, n = T.int64(), T.int64() gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) return gv0 diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 137713869e91..73cfacf1e526 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -708,29 +708,19 @@ def main( w1: R.Tensor((3,), dtype="float32"), w2: R.Tensor((3,), dtype="float32"), ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): - # block 0 with R.dataflow(): - lv: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.reshape( - input_1, (1, 3, 1, 10, 10) - ) - lv1: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.mean( - lv, axis=[2, 3, 4], keepdims=True - ) - lv2: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.subtract(lv, lv1) - lv3: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.multiply(lv2, lv2) - lv4: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.sum( - lv3, axis=[2, 3, 4], keepdims=True + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.group_norm( + input_1, + w1, + w2, + num_groups=3, + channel_axis=1, + axes=[2, 3], + epsilon=1.0000000000000001e-05, + center=True, + scale=True, ) - lv5: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.divide(lv4, R.const(100.0)) - lv6: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.add(lv5, R.const(1e-05)) - lv7: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.sqrt(lv6) - lv8: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.divide(lv2, lv7) - lv9: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.reshape(w1, (1, 3, 1, 1, 1)) - lv10: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.reshape(w2, (1, 3, 1, 1, 1)) - lv11: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.multiply(lv8, lv9) - lv12: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.add(lv11, lv10) - lv13: R.Tensor((1, 3, 10, 10), dtype="float32") = R.reshape(lv12, (1, 3, 10, 10)) - gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv13 + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv R.output(gv) return gv diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py index 5294596cee34..51144784638a 100644 --- a/tests/python/relax/test_op_nn.py +++ b/tests/python/relax/test_op_nn.py @@
-849,6 +849,244 @@ def test_layer_norm_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.layer_norm(x0, gamma1, beta, axes=[-2, -1])) +def test_group_norm_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4, 5))) + gamma0 = relax.Var("gamma", R.Tensor((4,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor("float32", ndim=1)) + gamma2 = relax.Var("gamma", R.Tensor((4,))) + beta0 = relax.Var("beta", R.Tensor((4,), "float32")) + beta1 = relax.Var("beta", R.Tensor((4,))) + + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo((2, 3, 4, 5), "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo((2, 3, 4, 5), "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x1, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x2, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma1, beta0, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo((2, 3, 4, 5), dtype="float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x3, gamma2, beta1, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo((2, 3, 4, 5), dtype=""), + ) + + +def test_group_norm_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c0 = tir.Var("c", "int64") + c1 = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((n, a, b, c0), "float32")) + x1 = relax.Var("x", R.Tensor((n, a, b, c1), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + gamma0 = relax.Var("gamma", R.Tensor((a,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((a,), "float32")) + beta = relax.Var("beta", R.Tensor((a,), "float32")) + + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma0, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), + relax.TensorStructInfo((n, a, b, c0), "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x1, gamma0, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), + relax.TensorStructInfo((n, a, b, c1), "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma1, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), + relax.TensorStructInfo((n, a, b, c0), "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x2, gamma0, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x2, gamma1, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + + +def test_group_norm_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s1", relax.ShapeStructInfo()) + s2 = relax.Var("s2", relax.ShapeStructInfo(ndim=1)) + s3 = relax.Var("s3", relax.ShapeStructInfo(ndim=1)) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + gamma = relax.Var("gamma", 
relax.TensorStructInfo(s2, "float32")) + beta = relax.Var("beta", relax.TensorStructInfo(s3, "float32")) + + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma, beta, num_groups=2, channel_axis=-2, axes=[1, 3]), + relax.TensorStructInfo(s0, "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x1, gamma, beta, num_groups=2, channel_axis=-2, axes=[1, 3]), + relax.TensorStructInfo(s1, "float32"), + ) + + +def test_group_norm_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16")) + gamma0 = relax.Var("gamma", R.Tensor((3,), "float16")) + beta0 = relax.Var("beta", R.Tensor((3,), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float64")) + gamma1 = relax.Var("gamma", R.Tensor((3,), "float64")) + beta1 = relax.Var("beta", R.Tensor((3,), "float64")) + + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma0, beta0, num_groups=3, channel_axis=1, axes=[-2, -1]), + relax.TensorStructInfo((2, 3, 4, 5), "float16"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x1, gamma1, beta1, num_groups=3, channel_axis=1, axes=[-2, -1]), + relax.TensorStructInfo((2, 3, 4, 5), "float64"), + ) + + +def test_group_norm_infer_struct_info_invalid_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8")) + gamma0 = relax.Var("gamma", R.Tensor((4,), "int8")) + beta0 = relax.Var("beta", R.Tensor((4,), "int8")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int32")) + gamma1 = relax.Var("gamma", R.Tensor((4,), "int32")) + beta1 = relax.Var("beta", R.Tensor((4,), "int32")) + + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x0, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x1, gamma1, beta1, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + + +def test_group_norm_infer_struct_info_axis_out_of_range_and_repetitive(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma = relax.Var("gamma", R.Tensor((4,), "float32")) + beta = relax.Var("beta", R.Tensor((4,), "float32")) + + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma, beta, num_groups=2, channel_axis=-2, axes=[3, 4]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma, beta, num_groups=2, channel_axis=-2, axes=[3, -1]) + ) + + +def test_group_norm_infer_struct_info_dtype_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((4,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((4,), "int8")) + beta0 = relax.Var("beta", R.Tensor((4,), "float32")) + beta1 = relax.Var("beta", R.Tensor((4,))) + + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma1, beta0, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma0, beta1, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + + +def test_group_norm_infer_struct_info_ndim_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((4,), "float32")) + beta0 = relax.Var("beta", R.Tensor((4, 5), "float32")) + beta1 = relax.Var("beta", R.Tensor((3, 4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma1, 
beta0, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma0, beta1, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + + +def test_group_norm_infer_struct_info_shape_mismatch(): + bb = relax.BlockBuilder() + c0 = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, c0), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((4, 6), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((4, c0), "float32")) + beta0 = relax.Var("beta", R.Tensor((4, 5), "float32")) + beta1 = relax.Var("beta", R.Tensor((4, c0 - 2), "float32")) + + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x0, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x1, gamma1, beta1, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + + +def test_group_norm_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) + gamma1 = relax.Var("gamma", relax.FuncStructInfo([], R.Tensor((4, 5), "float32"))) + beta = relax.Var("beta", R.Tensor((4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x1, gamma0, beta, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x0, gamma1, beta, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + + def test_dropout_infer_struct_info(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float32")) diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index 698ad2727456..8fb398f15d2b 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -1452,5 +1452,167 @@ def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_r tvm.ir.assert_structural_equal(mod, Expected) +def test_group_norm(): + # fmt: off + @tvm.script.ir_module + class GroupNorm: + @R.function + def main(x: R.Tensor((2, 4, 4, 5), "float32"), gamma: R.Tensor((4,), "float32"), beta: R.Tensor((4,), "float32")) -> R.Tensor((2, 4, 4, 5), "float32"): + gv: R.Tensor((2, 4, 4, 5), "float32") = R.nn.group_norm(x, gamma, beta, num_groups=2, channel_axis=1, axes=[2, 3]) + return gv + + @tvm.script.ir_module + class Expected: + @T.prim_func + def group_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4),), "float32"), rxplaceholder_2: T.Buffer((T.int64(4),), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.int64(5)), "float32")): + T.func_attr({"tir.noalias": True}) + T_reshape_1 = T.alloc_buffer((T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5))) + rxplaceholder_red_temp_v0 = T.alloc_buffer((T.int64(2), T.int64(2))) + rxplaceholder_red_temp_v1 = T.alloc_buffer((T.int64(2), T.int64(2))) + T_reshape_2 = T.alloc_buffer((T.int64(2), T.int64(2))) + T_reshape_3 = T.alloc_buffer((T.int64(2), T.int64(2))) + T_group_norm = T.alloc_buffer((T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5))) + for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2, 
v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(rxplaceholder[((v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) // T.int64(4) + v_ax0) % T.int64(2), (v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) % T.int64(4), (v_ax4 // T.int64(5) + v_ax3) % T.int64(4), v_ax4 % T.int64(5)]) + T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = rxplaceholder[((v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) // T.int64(4) + v_ax0) % T.int64(2), (v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) % T.int64(4), (v_ax4 // T.int64(5) + v_ax3) % T.int64(4), v_ax4 % T.int64(5)] + for ax0, ax1, k2, k3, k4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)): + with T.block("rxplaceholder_red_temp"): + v_ax0, v_ax1, v_k2, v_k3, v_k4 = T.axis.remap("SSRRR", [ax0, ax1, k2, k3, k4]) + T.reads(T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4]) + T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1]) + with T.init(): + rxplaceholder_red_temp_v0[v_ax0, v_ax1] = T.float32(0) + rxplaceholder_red_temp_v1[v_ax0, v_ax1] = T.float32(0) + v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] + v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] * T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] + rxplaceholder_red_temp_v0[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v0 + rxplaceholder_red_temp_v1[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v1 + for ax0, ax1 in T.grid(T.int64(2), T.int64(2)): + with T.block("T_reshape_1"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder_1[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)]) + T.writes(T_reshape_2[v_ax0, v_ax1]) + T_reshape_2[v_ax0, v_ax1] = rxplaceholder_1[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)] + for ax0, ax1 in T.grid(T.int64(2), T.int64(2)): + with T.block("T_reshape_2"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder_2[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)]) + T.writes(T_reshape_3[v_ax0, v_ax1]) + T_reshape_3[v_ax0, v_ax1] = rxplaceholder_2[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)] + for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)): + with T.block("T_group_norm"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1], T_reshape_2[v_ax1, v_ax2], T_reshape_3[v_ax1, v_ax2]) + T.writes(T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] - rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001)) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] * T.float32(0.025000000000000001) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001)) + T.float32(1.0000000000000001e-05)) * T_reshape_2[v_ax1, v_ax2] + T_reshape_3[v_ax1, v_ax2] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(4), T.int64(5)): + with T.block("T_reshape_3"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_group_norm[(((v_ax3 // 
T.int64(5) + v_ax2) // T.int64(4) + v_ax1) // T.int64(4) + v_ax0) % T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(4) // T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(2), (v_ax3 // T.int64(5) + v_ax2) % T.int64(4), v_ax3 % T.int64(5)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_group_norm[(((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) // T.int64(4) + v_ax0) % T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(4) // T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(2), (v_ax3 // T.int64(5) + v_ax2) % T.int64(4), v_ax3 % T.int64(5)] + + @R.function + def main(x: R.Tensor((2, 4, 4, 5), dtype="float32"), gamma: R.Tensor((4,), dtype="float32"), beta: R.Tensor((4,), dtype="float32")) -> R.Tensor((2, 4, 4, 5), dtype="float32"): + gv = R.call_tir(group_norm, (x, gamma, beta), out_sinfo=R.Tensor((2, 4, 4, 5), dtype="float32")) + return gv + # fmt: on + + mod = LegalizeOps()(GroupNorm) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_group_norm_symbolic(): + # fmt: off + @tvm.script.ir_module + class GroupNorm: + @R.function + def main(s: R.Shape(["c"]), x: R.Tensor(("n", "4 * c", "h", "w"), "float32"), gamma: R.Tensor(("4 * c",), "float32"), beta: R.Tensor(("4 * c",), "float32")) -> R.Tensor(("n", "4 * c", "h", "w"), "float32"): + n = T.int64() + c = T.int64() + h = T.int64() + w = T.int64() + gv: R.Tensor((n, 4 * c, h, w), "float32") = R.nn.group_norm(x, gamma, beta, num_groups=4, channel_axis=1, axes=[2, 3]) + return gv + + @tvm.script.ir_module + class Expected: + @T.prim_func + def group_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_reshape: T.handle, c: T.int64): + T.func_attr({"tir.noalias": True}) + n = T.int64() + h = T.int64() + w = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (n, T.int64(4) * c, h, w)) + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(4) * c,)) + rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (T.int64(4) * c,)) + T_reshape = T.match_buffer(var_T_reshape, (n, T.int64(4) * c, h, w)) + # with T.block("root"): + T_reshape_1 = T.alloc_buffer((n, T.int64(4), T.int64(4) * c // T.int64(4), h, w)) + rxplaceholder_red_temp_v0 = T.alloc_buffer((n, T.int64(4))) + rxplaceholder_red_temp_v1 = T.alloc_buffer((n, T.int64(4))) + T_reshape_2 = T.alloc_buffer((T.int64(4), T.int64(4) * c // T.int64(4))) + T_reshape_3 = T.alloc_buffer((T.int64(4), T.int64(4) * c // T.int64(4))) + T_group_norm = T.alloc_buffer((n, T.int64(4), T.int64(4) * c // T.int64(4), h, w)) + for ax0, ax1, ax2, ax3, ax4 in T.grid(n, T.int64(4), c, h, w): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(rxplaceholder[((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w // h // (c * T.int64(4)) % n, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w // h % (c * T.int64(4)), ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w % h, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) % w]) + T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = rxplaceholder[((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w // h // (c * T.int64(4)) % n, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w 
// h % (c * T.int64(4)), ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w % h, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) % w] + for ax0, ax1, k2, k3, k4 in T.grid(n, T.int64(4), c, h, w): + with T.block("rxplaceholder_red_temp"): + v_ax0, v_ax1, v_k2, v_k3, v_k4 = T.axis.remap("SSRRR", [ax0, ax1, k2, k3, k4]) + T.reads(T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4]) + T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1]) + with T.init(): + rxplaceholder_red_temp_v0[v_ax0, v_ax1] = T.float32(0) + rxplaceholder_red_temp_v1[v_ax0, v_ax1] = T.float32(0) + v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] + v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] * T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] + rxplaceholder_red_temp_v0[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v0 + rxplaceholder_red_temp_v1[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v1 + for ax0, ax1 in T.grid(T.int64(4), c): + with T.block("T_reshape_1"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder_1[(v_ax0 * c + v_ax1) % (c * T.int64(4))]) + T.writes(T_reshape_2[v_ax0, v_ax1]) + T_reshape_2[v_ax0, v_ax1] = rxplaceholder_1[(v_ax0 * c + v_ax1) % (c * T.int64(4))] + for ax0, ax1 in T.grid(T.int64(4), c): + with T.block("T_reshape_2"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder_2[(v_ax0 * c + v_ax1) % (c * T.int64(4))]) + T.writes(T_reshape_3[v_ax0, v_ax1]) + T_reshape_3[v_ax0, v_ax1] = rxplaceholder_2[(v_ax0 * c + v_ax1) % (c * T.int64(4))] + for ax0, ax1, ax2, ax3, ax4 in T.grid(n, T.int64(4), c, h, w): + with T.block("T_group_norm"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1], T_reshape_2[v_ax1, v_ax2], T_reshape_3[v_ax1, v_ax2]) + T.writes(T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w))) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w)) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w)) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w))) + T.float32(1.0000000000000001e-05)) * T_reshape_2[v_ax1, v_ax2] + T_reshape_3[v_ax1, v_ax2] + for ax0, ax1, ax2, ax3 in T.grid(n, c * T.int64(4), h, w): + with T.block("T_reshape_3"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_group_norm[(((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c // T.int64(4) % n, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c % T.int64(4), (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h % c, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w % h, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) % w]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_group_norm[(((v_ax0 * (c * T.int64(4)) + v_ax1) * h 
+ v_ax2) * w + v_ax3) // w // h // c // T.int64(4) % n, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c % T.int64(4), (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h % c, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w % h, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) % w] + + @R.function + def main(s: R.Shape(["c"]), x: R.Tensor(("n", "4 * c", "h", "w"), dtype="float32"), gamma: R.Tensor(("4 * c",), dtype="float32"), beta: R.Tensor(("4 * c",), dtype="float32")) -> R.Tensor(("n", "4 * c", "h", "w"), dtype="float32"): + n = T.int64() + c = T.int64() + h = T.int64() + w = T.int64() + gv = R.call_tir(group_norm, (x, gamma, beta), out_sinfo=R.Tensor((n, 4 * c, h, w), dtype="float32"), tir_vars=R.shape([c])) + return gv + # fmt: on + + mod = LegalizeOps()(GroupNorm) + tvm.ir.assert_structural_equal(mod, Expected) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_nn.py b/tests/python/relax/test_tvmscript_parser_op_nn.py index 781700af7b82..c2bfa5b7a9e9 100644 --- a/tests/python/relax/test_tvmscript_parser_op_nn.py +++ b/tests/python/relax/test_tvmscript_parser_op_nn.py @@ -185,6 +185,31 @@ def foo( _check(foo, bb.get()["foo"]) +def test_group_norm(): + @R.function + def foo( + x: R.Tensor((2, 4, 4, 5), "float32"), + gamma: R.Tensor((4,), "float32"), + beta: R.Tensor((4,), "float32"), + ) -> R.Tensor((2, 4, 4, 5), "float32"): + gv: R.Tensor((2, 4, 4, 5), "float32") = R.nn.group_norm( + x, gamma, beta, num_groups=2, channel_axis=1, axes=[2, 3] + ) + return gv + + x = relax.Var("x", R.Tensor((2, 4, 4, 5), "float32")) + gamma = relax.Var("gamma", R.Tensor((4,), "float32")) + beta = relax.Var("beta", R.Tensor((4,), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, gamma, beta]): + gv = bb.emit( + relax.op.nn.group_norm(x, gamma, beta, num_groups=2, channel_axis=1, axes=[2, 3]) + ) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + def test_dropout(): @R.function def foo( diff --git a/tests/python/topi/python/test_topi_group_norm.py b/tests/python/topi/python/test_topi_group_norm.py new file mode 100644 index 000000000000..f09442391672 --- /dev/null +++ b/tests/python/topi/python/test_topi_group_norm.py @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Test code for group_norm.""" +import numpy as np +import pytest +import tvm +from tvm import te +from tvm import topi +from tvm.topi.utils import get_const_tuple +import tvm.topi.testing + +import tvm.testing + + +_group_norm_schedule = { + "generic": topi.generic.schedule_injective, +} + + +# only test on llvm because schedule is missing +@tvm.testing.parametrize_targets("llvm") +@pytest.mark.parametrize("shape, axis", [([2, 4, 16], (2,)), ([2, 4, 4, 16], (2, 3))]) +def test_group_norm(target, dev, shape, axis, epsilon=1e-5, dtype="float32", rtol=1e-5, atol=1e-5): + data = te.placeholder(shape, dtype=dtype, name="data") + num_groups = 2 + channel_axis = 1 + gamma = te.placeholder((shape[channel_axis],), dtype=dtype, name="gamma") + beta = te.placeholder((shape[channel_axis],), dtype=dtype, name="beta") + B = topi.nn.group_norm(data, gamma, beta, num_groups, channel_axis, axis, epsilon) + + np.random.seed(0) + data_np = np.random.uniform(size=shape).astype(dtype) + gamma_np = np.random.uniform(size=(shape[channel_axis],)).astype(dtype) + beta_np = np.random.uniform(size=(shape[channel_axis],)).astype(dtype) + b_np = tvm.topi.testing.group_norm_python( + data_np, gamma_np, beta_np, num_groups, channel_axis, axis, epsilon + ) + + with tvm.target.Target(target): + s_func = tvm.topi.testing.dispatch(target, _group_norm_schedule) + s = s_func([B]) + data_tvm = tvm.nd.array(data_np, dev) + gamma_tvm = tvm.nd.array(gamma_np, dev) + beta_tvm = tvm.nd.array(beta_np, dev) + b_tvm = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), dev) + f = tvm.build(s, [data, gamma, beta, B], target) + f(data_tvm, gamma_tvm, beta_tvm, b_tvm) + tvm.testing.assert_allclose(b_tvm.numpy(), b_np, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/topi/python/test_topi_layer_norm.py b/tests/python/topi/python/test_topi_layer_norm.py index ead05470be3b..f875bb09e2a4 100644 --- a/tests/python/topi/python/test_topi_layer_norm.py +++ b/tests/python/topi/python/test_topi_layer_norm.py @@ -55,7 +55,7 @@ def test_layer_norm(target, dev, shape, axis, episilon=1e-5, dtype="float32", rt b_tvm = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), dev) f = tvm.build(s, [data, gamma, beta, B], target) f(data_tvm, gamma_tvm, beta_tvm, b_tvm) - tvm.testing.assert_allclose(b_tvm.asnumpy(), b_np, rtol=rtol, atol=atol) + tvm.testing.assert_allclose(b_tvm.numpy(), b_np, rtol=rtol, atol=atol) if __name__ == "__main__":