From d6b2c14a1b9afb978a1af36fb08fb73fc3f15504 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Fri, 14 Jul 2023 07:12:40 +0000 Subject: [PATCH 1/2] [OP] Add `rms_norm` This PR introduces the operator root mean square, `rms_norm`, into TOPI and relax, and its legalize transform. --- include/tvm/relax/attrs/nn.h | 11 + include/tvm/topi/nn/rms_norm.h | 94 +++++++ python/tvm/relax/op/nn/nn.py | 40 +++ python/tvm/relax/transform/legalize_ops/nn.py | 11 + python/tvm/topi/nn/__init__.py | 1 + python/tvm/topi/nn/rms_norm.py | 45 +++ python/tvm/topi/testing/__init__.py | 1 + python/tvm/topi/testing/rms_norm_python.py | 48 ++++ src/relax/op/nn/nn.cc | 59 ++++ src/relax/op/nn/nn.h | 3 + src/topi/nn.cc | 6 + .../relax/test_transform_legalize_ops_nn.py | 260 ++++++++++++++++++ .../python/topi/python/test_topi_rms_norm.py | 60 ++++ 13 files changed, 639 insertions(+) create mode 100644 include/tvm/topi/nn/rms_norm.h create mode 100644 python/tvm/topi/nn/rms_norm.py create mode 100644 python/tvm/topi/testing/rms_norm_python.py create mode 100644 tests/python/topi/python/test_topi_rms_norm.py diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index b759ce6c2686..38f84f1856e1 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -287,6 +287,17 @@ struct GroupNormAttrs : public tvm::AttrsNode { } }; // struct GroupNormAttrs +/*! \brief Attributes used in rms_norm operator */ +struct RMSNormAttrs : public tvm::AttrsNode { + Array axes; + double epsilon; + + TVM_DECLARE_ATTRS(RMSNormAttrs, "relax.attrs.RMSNormAttrs") { + TVM_ATTR_FIELD(axes).describe("The axes that along which the normalization is applied."); + TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero"); + } +}; // struct RMSNormAttrs + /*! \brief Attributes used in nll_loss operator */ struct NLLLossAttrs : public tvm::AttrsNode { String reduction; diff --git a/include/tvm/topi/nn/rms_norm.h b/include/tvm/topi/nn/rms_norm.h new file mode 100644 index 000000000000..e743205611c3 --- /dev/null +++ b/include/tvm/topi/nn/rms_norm.h @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \brief root mean square normalization op constructions + * \file nn/rms_norm.h + */ +#ifndef TVM_TOPI_NN_RMS_NORM_H_ +#define TVM_TOPI_NN_RMS_NORM_H_ + +#include +#include +#include + +#include + +namespace tvm { +namespace topi { +namespace nn { + +using namespace tvm::te; + +/*! + * \brief Root mean square normalization. + * \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}] + * \param weight K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and + * d_{axis_k} == r_k + * \param axis The axis to normalize over. + * \param epsilon The epsilon value to avoid division by zero. 
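+ * \note data and weight must share the same dtype (float32 or float16); float16 inputs are
+ *       computed in float32 internally and the result is cast back to float16 on output.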
+ * \param name The name of the operation. + * \param tag The tag to mark the operation. + * \return The normalized tensor, with the same shape as data. + */ +inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array& axis, + double epsilon, std::string name = "T_rms_norm", + std::string tag = kInjective) { + const auto& data_type = data->dtype; + const auto& weight_type = weight.defined() ? weight->dtype : data_type; + ICHECK(data_type == weight_type) << "rms_norm: data and weight must have the same type"; + ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16)) + << "rms_norm: only support float32 and float16 for now"; + bool is_float16 = data_type == DataType::Float(16); + + auto x = is_float16 ? cast(data, DataType::Float(32)) : data; + auto w = is_float16 ? cast(weight, DataType::Float(32)) : weight; + auto square = multiply(x, x); + auto square_sum = sum(square, axis, /*keepdims=*/false, /*atleast1d=*/true); + + auto ndim = data->shape.size(); + ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; + auto real_axis = GetRealAxis(static_cast(ndim), axis); + auto reduce_extent = make_const(data->dtype, 1); + for (int i : real_axis) { + reduce_extent *= data->shape[i]; + } + auto rms_norm_func = [&](const Array& indices) { + Array reduce_indices, non_reduce_indices; + for (int i = 0, n = static_cast(indices.size()); i < n; ++i) { + if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) { + reduce_indices.push_back(indices[i]); + } else { + non_reduce_indices.push_back(indices[i]); + } + } + auto output = + x(indices) * w(reduce_indices) * + tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + make_const(data_type, epsilon)); + return output; + }; + auto rms_norm = tvm::te::compute(data->shape, rms_norm_func, name, tag); + return is_float16 ? cast(rms_norm, DataType::Float(16)) : rms_norm; +} + +} // namespace nn +} // namespace topi +} // namespace tvm + +#endif // TVM_TOPI_NN_RMS_NORM_H_ diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index be2f9685017b..091666a84ec7 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -902,6 +902,46 @@ def group_norm( ) +def rms_norm( + data: Expr, + weight: Expr, + axes: Union[int, List[int]], + epsilon: float = 1e-5, +) -> Expr: + r""" + Root mean square normalization (Biao Zhang and et al., 2019). + Applies root mean square normalization to the n-dimensional input array. + This operator takes an n-dimensional input array and normalizes + the input using the given axis: + + .. math:: + + out = \frac{data}{\sqrt{mean(data, axis)+\epsilon}} * weight + + Parameters + ---------- + data : relax.Expr + Input to which rms_norm will be applied. + + weight : relax.Expr + The scale factor. + + axes : Union[int, List[int]] + The axes that along which the normalization is applied. + + epsilon : float + Small float added to variance to avoid dividing by zero. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axes, int): + axes = [axes] + return _ffi_api.rms_norm(data, weight, axes, epsilon) # type: ignore + + def dropout(data: Expr, rate: float = 0.5) -> Expr: """Applies the dropout operation to the input tensor. 
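For reference, here is a minimal sketch (not part of the diff) of how the new Relax operator is expected to be invoked from TVMScript and then lowered, mirroring the legalization tests added later in this patch; the Module/main names and the tensor shapes are illustrative only:

    import tvm
    from tvm import relax
    from tvm.script import relax as R

    @tvm.script.ir_module
    class Module:
        @R.function
        def main(
            x: R.Tensor((2, 3, 4, 5), "float32"),
            weight: R.Tensor((4, 5), "float32"),
        ) -> R.Tensor((2, 3, 4, 5), "float32"):
            # Normalizes over the last two axes; computes
            # x / sqrt(mean(x * x, axes) + epsilon) * weight, with epsilon defaulting to 1e-5.
            gv: R.Tensor((2, 3, 4, 5), "float32") = R.nn.rms_norm(x, weight, axes=[-2, -1])
            return gv

    # Lower the high-level op to TIR via the legalization rule added in this patch.
    lowered = relax.transform.LegalizeOps()(Module)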
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 514c3d07822d..fca1da212606 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -329,6 +329,17 @@ def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr: ) +@register_legalize("relax.nn.rms_norm") +def _nn_layer_norm(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + topi.nn.rms_norm, + call.args[0], + call.args[1], + axis=call.attrs.axes, + epsilon=call.attrs.epsilon, + ) + + @register_legalize("relax.nn.dropout") def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr: logging.info("Dropout is handled by frontend translator at this moment and is not legalized.") diff --git a/python/tvm/topi/nn/__init__.py b/python/tvm/topi/nn/__init__.py index d65c5c45c7e0..2c549cc5b9cf 100644 --- a/python/tvm/topi/nn/__init__.py +++ b/python/tvm/topi/nn/__init__.py @@ -41,6 +41,7 @@ from .instance_norm import instance_norm from .layer_norm import layer_norm from .group_norm import group_norm +from .rms_norm import rms_norm from .local_response_norm import * from .bitserial_conv2d import * from .bitserial_dense import * diff --git a/python/tvm/topi/nn/rms_norm.py b/python/tvm/topi/nn/rms_norm.py new file mode 100644 index 000000000000..651ff361bfb9 --- /dev/null +++ b/python/tvm/topi/nn/rms_norm.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Root mean square normalization operator.""" +from .. import cpp + + +def rms_norm(data, weight, axis, epsilon=1e-5): + """Root mean square normalization operator. + It accepts fp16 and fp32 as input data type. It will cast the input to fp32 + to perform the computation. The output will have the same data type as input. + + Parameters + ---------- + data : tvm.te.Tensor + N-D with shape (d_0, d_1, ..., d_{N-1}) + + weight: tvm.te.Tensor + K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k + + axis : list of int + Axis over the normalization applied + + epsilon : float + The epsilon value to avoid division by zero. 
+ + Returns + ------- + result : tvm.te.Tensor + N-D with shape (d_0, d_1, ..., d_{N-1}) + """ + return cpp.nn.rms_norm(data, weight, axis, epsilon) diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index d950a20c0559..093f84d99bd3 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -46,6 +46,7 @@ from .instance_norm_python import instance_norm_python from .layer_norm_python import layer_norm_python from .group_norm_python import group_norm_python +from .rms_norm_python import rms_norm_python from .lrn_python import lrn_python from .l2_normalize_python import l2_normalize_python from .gather_python import gather_python diff --git a/python/tvm/topi/testing/rms_norm_python.py b/python/tvm/topi/testing/rms_norm_python.py new file mode 100644 index 000000000000..0273b419413c --- /dev/null +++ b/python/tvm/topi/testing/rms_norm_python.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals +"""Root mean square normalization in python""" +import numpy as np + + +def rms_norm_python(data, weight, axis, epsilon=1e-5): + """Root mean square normalization operator in Python. + + Parameters + ---------- + data : numpy.ndarray + N-D with shape (d_0, d_1, ..., d_{N-1}) + + weight: numpy.ndarray + K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k + + axis : int or tuple of ints + Axis over the normalization applied + + epsilon : float + The epsilon value to avoid division by zero. 
+ + Returns + ------- + result : np.ndarray + N-D with shape (d_0, d_1, ..., d_{N-1}) + """ + old_dtype = data.dtype + data = data.astype("float32") + square_mean = np.mean(np.square(data), axis, keepdims=True) + result = data * weight / np.sqrt(square_mean + epsilon) + return result.astype(old_dtype) diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 0e7a957a6914..9baf6839d3ce 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -417,6 +417,65 @@ TVM_REGISTER_OP("relax.nn.group_norm") .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", Bool(true)); +/* relax.nn.rms_norm */ +TVM_REGISTER_NODE_TYPE(RMSNormAttrs); + +Expr rms_norm(Expr data, Expr weight, Array axes, double epsilon) { + ObjectPtr attrs = make_object(); + attrs->axes = std::move(axes); + attrs->epsilon = epsilon; + + static const Op& op = Op::Get("relax.nn.rms_norm"); + return Call(op, {std::move(data), std::move(weight)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.rms_norm").set_body_typed(rms_norm); + +StructInfo InferStructInfoRMSNorm(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + + const auto* attrs = call->attrs.as(); + bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, attrs->axes); + + return unknown_shape ? TensorStructInfo(input_sinfo[0]->dtype, input_sinfo[0]->ndim) + : input_sinfo[0]; +} + +InferLayoutOutput InferLayoutRMSNorm(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ICHECK(NoDesiredLayout(call, desired_layouts)); + std::vector initial_layouts; + for (size_t i = 0; i < 3; ++i) { + const auto* tensor_sinfo = GetStructInfoAs(call->args[i]); + ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; + initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim)); + } + const auto* attrs = call->attrs.as(); + ICHECK(attrs) << "Invalid Call"; + + LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); + ObjectPtr new_attrs = make_object(*attrs); + std::vector new_axis; + for (const auto& axis : attrs->axes) { + new_axis.push_back(FindAxis(layout->layout, axis->value)); + } + new_attrs->axes = std::move(new_axis); + return InferLayoutOutput({layout, initial_layouts[1], initial_layouts[2]}, {layout}, + Attrs(new_attrs)); +} + +TVM_REGISTER_OP("relax.nn.rms_norm") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "Input to which batch_norm will be applied.") + .add_argument("weight", "Tensor", "Input to which batch_norm will be applied.") + .set_attr("FInferStructInfo", InferStructInfoRMSNorm) + .set_attr("FRelaxInferLayout", InferLayoutRMSNorm) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); + /* relax.nn.dropout */ TVM_REGISTER_NODE_TYPE(DropoutAttrs); diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h index 38e605bb0b83..557f79dbac0d 100644 --- a/src/relax/op/nn/nn.h +++ b/src/relax/op/nn/nn.h @@ -75,6 +75,9 @@ Expr layer_norm(Expr data, Expr gamma, Expr beta, Array axes, double ep Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis, Array axes, double epsilon, bool center, bool scale); +/*! \brief Compute root mean square normalization. */ +Expr rms_norm(Expr data, Expr weight, Array axes, double epsilon); + /*! * \brief Applies the dropout operation to the input tensor. 
* \param data The input data to the operator. diff --git a/src/topi/nn.cc b/src/topi/nn.cc index 58b962da6afa..9ce329b20637 100644 --- a/src/topi/nn.cc +++ b/src/topi/nn.cc @@ -35,6 +35,7 @@ #include #include #include +#include #include namespace tvm { @@ -176,5 +177,10 @@ TVM_REGISTER_GLOBAL("topi.nn.instance_norm").set_body([](TVMArgs args, TVMRetVal *rv = nn::instance_norm(args[0], args[1], args[2], args[3], static_cast(args[4])); }); +/* Ops from nn/rms_norm.h */ +TVM_REGISTER_GLOBAL("topi.nn.rms_norm").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = nn::rms_norm(args[0], args[1], args[2], static_cast(args[3])); +}); + } // namespace topi } // namespace tvm diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index cbbacbabda5e..f8557fc8d0b4 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -2669,6 +2669,266 @@ def main(s: R.Shape(["c"]), x: R.Tensor(("n", "4 * c", "h", "w"), dtype="float32 tvm.ir.assert_structural_equal(mod, Expected) +def test_rms_norm(): + # fmt: off + @tvm.script.ir_module + class RMSNorm: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32"), weight: R.Tensor((4, 5), "float32")) -> R.Tensor((2, 3, 4, 5), "float32"): + gv: R.Tensor((2, 3, 4, 5), "float32") = R.nn.rms_norm(x, weight, axes=[-2, -1]) + return gv + + @tvm.script.ir_module + class Expected: + @T.prim_func + def rms_norm( + A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), + B: T.Buffer((T.int64(4), T.int64(5)), "float32"), + T_rms_norm: T.Buffer( + (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32" + ), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) + T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3))) + for ax0, ax1, ax2, ax3 in T.grid( + T.int64(2), T.int64(3), T.int64(4), T.int64(5) + ): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) + T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = ( + A[v_ax0, v_ax1, v_ax2, v_ax3] * A[v_ax0, v_ax1, v_ax2, v_ax3] + ) + for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("T_multiply_red"): + v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3]) + T.reads(T_multiply[v_ax0, v_ax1, v_k2, v_k3]) + T.writes(T_multiply_red[v_ax0, v_ax1]) + with T.init(): + T_multiply_red[v_ax0, v_ax1] = T.float32(0) + T_multiply_red[v_ax0, v_ax1] = ( + T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3] + ) + for ax0, ax1, ax2, ax3 in T.grid( + T.int64(2), T.int64(3), T.int64(4), T.int64(5) + ): + with T.block("T_rms_norm"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads( + A[v_ax0, v_ax1, v_ax2, v_ax3], + B[v_ax2, v_ax3], + T_multiply_red[v_ax0, v_ax1], + ) + T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3]) + T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = ( + A[v_ax0, v_ax1, v_ax2, v_ax3] + * B[v_ax2, v_ax3] + * T.rsqrt( + T_multiply_red[v_ax0, v_ax1] * T.float32(0.05) + + T.float32(1e-05) + ) + ) + + @R.function + def main( + x: R.Tensor((2, 3, 4, 5), dtype="float32"), + weight: R.Tensor((4, 5), dtype="float32"), + ) -> R.Tensor((2, 3, 4, 5), dtype="float32"): + cls = Expected + gv = R.call_tir( + cls.rms_norm, (x, weight), 
out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float32") + ) + return gv + # fmt: on + mod = LegalizeOps()(RMSNorm) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_rms_norm_fp16(): + # fmt: off + @tvm.script.ir_module + class RMSNorm: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float16"), weight: R.Tensor((4, 5), "float16")) -> R.Tensor((2, 3, 4, 5), "float16"): + gv: R.Tensor((2, 3, 4, 5), "float16") = R.nn.rms_norm(x, weight, axes=[-2, -1]) + return gv + + @tvm.script.ir_module + class Expected: + @T.prim_func + def rms_norm( + A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"), + B: T.Buffer((T.int64(4), T.int64(5)), "float16"), + T_cast: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) + T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5))) + T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) + T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3))) + T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) + for ax0, ax1, ax2, ax3 in T.grid( + T.int64(2), T.int64(3), T.int64(4), T.int64(5) + ): + with T.block("T_cast"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast( + "float32", A[v_ax0, v_ax1, v_ax2, v_ax3] + ) + for ax0, ax1 in T.grid(T.int64(4), T.int64(5)): + with T.block("T_cast_1"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(B[v_ax0, v_ax1]) + T.writes(T_cast_2[v_ax0, v_ax1]) + T_cast_2[v_ax0, v_ax1] = T.Cast("float32", B[v_ax0, v_ax1]) + for ax0, ax1, ax2, ax3 in T.grid( + T.int64(2), T.int64(3), T.int64(4), T.int64(5) + ): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) + T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = ( + T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] + * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] + ) + for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): + with T.block("T_multiply_red"): + v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3]) + T.reads(T_multiply[v_ax0, v_ax1, v_k2, v_k3]) + T.writes(T_multiply_red[v_ax0, v_ax1]) + with T.init(): + T_multiply_red[v_ax0, v_ax1] = T.float32(0) + T_multiply_red[v_ax0, v_ax1] = ( + T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3] + ) + for ax0, ax1, ax2, ax3 in T.grid( + T.int64(2), T.int64(3), T.int64(4), T.int64(5) + ): + with T.block("T_rms_norm"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads( + T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], + T_cast_2[v_ax2, v_ax3], + T_multiply_red[v_ax0, v_ax1], + ) + T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3]) + T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = ( + T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] + * T_cast_2[v_ax2, v_ax3] + * T.rsqrt( + T_multiply_red[v_ax0, v_ax1] + / T.Cast("float32", T.float16(4) * T.float16(5)) + + T.float32(1e-05) + ) + ) + for ax0, ax1, ax2, ax3 in T.grid( + T.int64(2), T.int64(3), T.int64(4), T.int64(5) + ): + with T.block("T_cast_2"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(T_cast[v_ax0, v_ax1, v_ax2, 
v_ax3]) + T_cast[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast( + "float16", T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] + ) + + @R.function + def main( + x: R.Tensor((2, 3, 4, 5), dtype="float16"), + weight: R.Tensor((4, 5), dtype="float16"), + ) -> R.Tensor((2, 3, 4, 5), dtype="float16"): + cls = Expected + gv = R.call_tir( + cls.rms_norm, (x, weight), out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float16") + ) + return gv + # fmt: on + mod = LegalizeOps()(RMSNorm) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_rms_norm_symbolic(): + # fmt: off + @tvm.script.ir_module + class RMSNorm: + @R.function + def main(x: R.Tensor(("n", "s", "f"), "float32"), weight: R.Tensor(("s", "f"), "float32")) -> R.Tensor(("n", "s", "f"), "float32"): + n = T.int64() + s = T.int64() + f = T.int64() + gv: R.Tensor((n, s, f), "float32") = R.nn.rms_norm(x, weight, axes=[1, 2]) + return gv + + @tvm.script.ir_module + class Expected: + @T.prim_func + def rms_norm(var_A: T.handle, var_B: T.handle, var_T_rms_norm: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n, s, f = T.int64(), T.int64(), T.int64() + A = T.match_buffer(var_A, (n, s, f)) + B = T.match_buffer(var_B, (s, f)) + T_rms_norm = T.match_buffer(var_T_rms_norm, (n, s, f)) + # with T.block("root"): + T_multiply = T.alloc_buffer((n, s, f)) + T_multiply_red = T.alloc_buffer((n,)) + for ax0, ax1, ax2 in T.grid(n, s, f): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(A[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) + T_multiply[v_ax0, v_ax1, v_ax2] = ( + A[v_ax0, v_ax1, v_ax2] * A[v_ax0, v_ax1, v_ax2] + ) + for ax0, k1, k2 in T.grid(n, s, f): + with T.block("T_multiply_red"): + v_ax0, v_k1, v_k2 = T.axis.remap("SRR", [ax0, k1, k2]) + T.reads(T_multiply[v_ax0, v_k1, v_k2]) + T.writes(T_multiply_red[v_ax0]) + with T.init(): + T_multiply_red[v_ax0] = T.float32(0) + T_multiply_red[v_ax0] = ( + T_multiply_red[v_ax0] + T_multiply[v_ax0, v_k1, v_k2] + ) + for ax0, ax1, ax2 in T.grid(n, s, f): + with T.block("T_rms_norm"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(A[v_ax0, v_ax1, v_ax2], B[v_ax1, v_ax2], T_multiply_red[v_ax0]) + T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2]) + T_rms_norm[v_ax0, v_ax1, v_ax2] = ( + A[v_ax0, v_ax1, v_ax2] + * B[v_ax1, v_ax2] + * T.rsqrt( + T_multiply_red[v_ax0] + / (T.Cast("float32", s) * T.Cast("float32", f)) + + T.float32(1e-05) + ) + ) + + @R.function + def main( + x: R.Tensor(("n", "s", "f"), dtype="float32"), + weight: R.Tensor(("s", "f"), dtype="float32"), + ) -> R.Tensor(("n", "s", "f"), dtype="float32"): + n = T.int64() + s = T.int64() + f = T.int64() + cls = Expected + gv = R.call_tir( + cls.rms_norm, (x, weight), out_sinfo=R.Tensor((n, s, f), dtype="float32") + ) + return gv + # fmt: on + mod = LegalizeOps()(RMSNorm) + tvm.ir.assert_structural_equal(mod, Expected) + + def test_attention(): # fmt: off @tvm.script.ir_module diff --git a/tests/python/topi/python/test_topi_rms_norm.py b/tests/python/topi/python/test_topi_rms_norm.py new file mode 100644 index 000000000000..a30c5bbc97f8 --- /dev/null +++ b/tests/python/topi/python/test_topi_rms_norm.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test code for rms_norm.""" +import numpy as np +import pytest +import tvm +from tvm import te +from tvm import topi +from tvm.topi.utils import get_const_tuple +import tvm.topi.testing + +import tvm.testing + + +_rms_norm_schedule = { + "generic": topi.generic.schedule_injective, +} + + +# only test on llvm because schedule is missing +@tvm.testing.parametrize_targets("llvm") +@pytest.mark.parametrize("shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2))]) +@pytest.mark.parametrize("dtype", ["float32", "float16"]) +def test_rms_norm(target, dev, shape, axis, dtype, episilon=1e-5, rtol=5e-4, atol=5e-4): + data = te.placeholder(shape, dtype=dtype, name="data") + scale_shape = [shape[dim] for dim in axis] + weight = te.placeholder(scale_shape, dtype=dtype, name="weight") + B = topi.nn.rms_norm(data, weight, axis, episilon) + + data_np = np.random.uniform(size=shape).astype(dtype) + weight_np = np.random.uniform(size=scale_shape).astype(dtype) + b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, axis, episilon) + + with tvm.target.Target(target): + s_func = tvm.topi.testing.dispatch(target, _rms_norm_schedule) + s = s_func([B]) + data_tvm = tvm.nd.array(data_np, dev) + weight_tvm = tvm.nd.array(weight_np, dev) + b_tvm = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), dev) + f = tvm.build(s, [data, weight, B], target) + f(data_tvm, weight_tvm, b_tvm) + tvm.testing.assert_allclose(b_tvm.numpy(), b_np, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + tvm.testing.main() From 182cd827f875ad1b95dd34488244d110d91935b6 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Fri, 14 Jul 2023 08:12:51 +0000 Subject: [PATCH 2/2] fix duplicate function name --- python/tvm/relax/transform/legalize_ops/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index fca1da212606..75d7a2bc2019 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -330,7 +330,7 @@ def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr: @register_legalize("relax.nn.rms_norm") -def _nn_layer_norm(bb: BlockBuilder, call: Call) -> Expr: +def _nn_rms_norm(bb: BlockBuilder, call: Call) -> Expr: return bb.call_te( topi.nn.rms_norm, call.args[0],
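As a quick standalone reference (not part of either commit), the computation these kernels implement can be written in a few lines of NumPy, equivalent to the rms_norm_python helper added in this patch; the rms_norm_ref name, shapes, and dtypes below are illustrative:

    import numpy as np

    def rms_norm_ref(data, weight, axis, epsilon=1e-5):
        # Accumulate in float32, as the TOPI kernel does for float16 inputs.
        x = data.astype("float32")
        # Mean of squares over the normalized axes, kept for broadcasting.
        mean_sq = np.mean(np.square(x), axis=axis, keepdims=True)
        # Scale by weight and cast back to the input dtype.
        return (x * weight / np.sqrt(mean_sq + epsilon)).astype(data.dtype)

    x = np.random.uniform(size=(2, 3, 4, 5)).astype("float16")
    w = np.random.uniform(size=(4, 5)).astype("float16")
    out = rms_norm_ref(x, w, axis=(-2, -1))
    assert out.shape == x.shape and out.dtype == x.dtype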