diff --git a/include/tvm/topi/nn/rms_norm.h b/include/tvm/topi/nn/rms_norm.h
new file mode 100644
index 000000000000..44d38bae6d7a
--- /dev/null
+++ b/include/tvm/topi/nn/rms_norm.h
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Construction of the root mean square normalization (RMSNorm) operator.
+ * \file nn/rms_norm.h
+ */
+#ifndef TVM_TOPI_NN_RMS_NORM_H_
+#define TVM_TOPI_NN_RMS_NORM_H_
+
+#include
+#include
+#include
+
+#include
+
+namespace tvm {
+namespace topi {
+namespace nn {
+
+using namespace tvm::te;
+
+/*!
+ * \brief RMS normalization: data * weight * rsqrt(mean(data^2 over axis) + epsilon), plus bias if defined.
+ * \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}]
+ * \param weight K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and
+ *               d_{axis_k} == r_k; its dtype must equal the dtype of data
+ * \param bias Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where
+ *             d_{axis_k} == r_k; added after scaling only when defined, same dtype as data
+ * \param axis The axes to normalize (reduce) over.
+ * \param epsilon Added to the mean of squares before the reciprocal square root is taken.
+ * \param name The name of the operation.
+ * \param tag The tag to mark the operation.
+ * \return The normalized tensor, with the same shape as data.
+ */
+inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Tensor& bias,
+                       const Array& axis, double epsilon, std::string name = "T_rms_norm",
+                       std::string tag = kInjective) {
+  const auto& data_type = data->dtype;
+  const auto& weight_type = weight.defined() ? weight->dtype : data_type;  // undefined weight passes the check
+  ICHECK(data_type == weight_type) << "rms_norm: data and weight must have the same type";
+  const auto& bias_type = bias.defined() ? bias->dtype : data_type;  // undefined bias passes the check
+  ICHECK(data_type == bias_type) << "rms_norm: data and bias must have the same type";
+
+  auto square = multiply(data, data);  // data^2, the input of the reduction below
+  auto square_sum = sum(square, axis, /*keepdims=*/false, /*atleast1d=*/true);  // indexed by non-reduced dims below
+
+  auto ndim = data->shape.size();
+  ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
+  auto real_axis = GetRealAxis(static_cast(ndim), axis);  // NOTE(review): cast target type is missing here - confirm
+  auto reduce_extent = make_const(data->dtype, 1);  // product of the reduced extents, built in data's dtype
+  for (int i : real_axis) {
+    reduce_extent *= data->shape[i];
+  }
+  auto rms_norm_func = [&](const Array& indices) {
+    Array reduce_indices, non_reduce_indices;  // output indices split by membership in real_axis
+    for (int i = 0, n = static_cast(indices.size()); i < n; ++i) {
+      if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
+        reduce_indices.push_back(indices[i]);  // addresses weight / bias
+      } else {
+        non_reduce_indices.push_back(indices[i]);  // addresses square_sum
+      }
+    }
+    auto output =
+        data(indices) * weight(reduce_indices) *  // NOTE(review): weight is indexed even if undefined - confirm intent
+        tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + make_const(data_type, epsilon));
+    if (bias.defined()) {
+      output += bias(reduce_indices);
+    }
+    return output;
+  };
+  auto rms_norm = tvm::te::compute(data->shape, rms_norm_func, name, tag);  // result takes data's shape
+  return rms_norm;
+}
+
+}  // namespace nn
+}  // namespace topi
+}  // namespace tvm
+
+#endif  // TVM_TOPI_NN_RMS_NORM_H_
diff --git a/python/tvm/topi/nn/__init__.py b/python/tvm/topi/nn/__init__.py
index d65c5c45c7e0..2c549cc5b9cf 100644
--- a/python/tvm/topi/nn/__init__.py
+++ b/python/tvm/topi/nn/__init__.py
@@ -41,6 +41,7 @@
 from .instance_norm import instance_norm
 from .layer_norm import layer_norm
from .group_norm import group_norm
+from .rms_norm import rms_norm
 from .local_response_norm import *
 from .bitserial_conv2d import *
 from .bitserial_dense import *
diff --git a/python/tvm/topi/nn/rms_norm.py b/python/tvm/topi/nn/rms_norm.py
new file mode 100644
index 000000000000..f2f5a7e67487
--- /dev/null
+++ b/python/tvm/topi/nn/rms_norm.py
@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Root mean square normalization operator."""
+from .. import cpp
+
+
+def rms_norm(data, weight, bias, axis, epsilon=1e-5):
+    """Root mean square normalization operator. The output has the same data type as the input.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        N-D with shape (d_0, d_1, ..., d_{N-1})
+
+    weight : tvm.te.Tensor
+        K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
+
+    bias : tvm.te.Tensor
+        Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
+
+    axis : list of int
+        Axes over which the normalization is applied
+
+    epsilon : float
+        The epsilon value to avoid division by zero (defaults to 1e-5).
+
+    Returns
+    -------
+    result : tvm.te.Tensor
+        N-D with shape (d_0, d_1, ..., d_{N-1})
+    """
+    return cpp.nn.rms_norm(data, weight, bias, axis, epsilon)  # all arguments are forwarded unchanged to cpp.nn
diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py
index d950a20c0559..093f84d99bd3 100644
--- a/python/tvm/topi/testing/__init__.py
+++ b/python/tvm/topi/testing/__init__.py
@@ -46,6 +46,7 @@
 from .instance_norm_python import instance_norm_python
 from .layer_norm_python import layer_norm_python
 from .group_norm_python import group_norm_python
+from .rms_norm_python import rms_norm_python
 from .lrn_python import lrn_python
 from .l2_normalize_python import l2_normalize_python
 from .gather_python import gather_python
diff --git a/python/tvm/topi/testing/rms_norm_python.py b/python/tvm/topi/testing/rms_norm_python.py
new file mode 100644
index 000000000000..7fad5d57ce10
--- /dev/null
+++ b/python/tvm/topi/testing/rms_norm_python.py
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
+"""Root mean square normalization reference implementation in NumPy."""
+import numpy as np
+
+
+def rms_norm_python(data, weight, bias, axis, epsilon=1e-5):
+    """RMS normalization in NumPy: data * weight / sqrt(mean(data**2, axis) + epsilon), plus bias if given.
+
+    Parameters
+    ----------
+    data : numpy.ndarray
+        N-D with shape (d_0, d_1, ..., d_{N-1})
+
+    weight : numpy.ndarray
+        K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
+
+    bias : numpy.ndarray or None
+        Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
+
+    axis : int or tuple of ints
+        Axes over which the normalization is applied
+
+    epsilon : float
+        The epsilon value to avoid division by zero.
+
+    Returns
+    -------
+    result : np.ndarray
+        N-D with shape (d_0, d_1, ..., d_{N-1})
+    """
+    square_mean = np.mean(np.square(data), axis, keepdims=True)  # keepdims so the mean broadcasts against data
+    result = data * weight / np.sqrt(square_mean + epsilon)  # NOTE(review): weight/bias broadcast on trailing dims; presumably axis must be the last K dims - confirm
+    if bias is not None:
+        result += bias  # in place on the freshly computed array
+    return result
diff --git a/src/topi/nn.cc b/src/topi/nn.cc
index 58b962da6afa..ba88f01c6850 100644
--- a/src/topi/nn.cc
+++ b/src/topi/nn.cc
@@ -35,6 +35,7 @@
 #include
 #include
 #include
+#include
 #include
 
 namespace tvm {
@@ -176,5 +177,10 @@ TVM_REGISTER_GLOBAL("topi.nn.instance_norm").set_body([](TVMArgs args, TVMRetVal
   *rv = nn::instance_norm(args[0], args[1], args[2], args[3], static_cast(args[4]));
 });
 
+/* Ops from nn/rms_norm.h: args are (data, weight, bias, axis, epsilon) */
+TVM_REGISTER_GLOBAL("topi.nn.rms_norm").set_body([](TVMArgs args, TVMRetValue* rv) {
+  *rv = nn::rms_norm(args[0], args[1], args[2], args[3], static_cast(args[4]));
+});
+
 }  // namespace topi
 }  // namespace tvm
diff --git a/tests/python/topi/python/test_topi_rms_norm.py b/tests/python/topi/python/test_topi_rms_norm.py
new file mode 100644
index 000000000000..35a1485afa6b
--- /dev/null
+++ b/tests/python/topi/python/test_topi_rms_norm.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test code for rms_norm.""" +import numpy as np +import pytest +import tvm +from tvm import te +from tvm import topi +from tvm.topi.utils import get_const_tuple +import tvm.topi.testing + +import tvm.testing + + +_rms_norm_schedule = { + "generic": topi.generic.schedule_injective, +} + + +# only test on llvm because schedule is missing +@tvm.testing.parametrize_targets("llvm") +@pytest.mark.parametrize( + "shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2)), ([("a", 4), ("b", 16)], (1,))] +) +@pytest.mark.parametrize("dtype", ["float32", "float16"]) +def test_rms_norm(target, dev, shape, axis, dtype, episilon=1e-5, rtol=5e-3, atol=1e-4): + shape_te = [te.var(v[0]) if isinstance(v, tuple) else v for v in shape] + scale_shape_te = [shape_te[dim] for dim in axis] + data = te.placeholder(shape_te, dtype=dtype, name="data") + weight = te.placeholder(scale_shape_te, dtype=dtype, name="weight") + bias = te.placeholder(scale_shape_te, dtype=dtype, name="weight") + B = topi.nn.rms_norm(data, weight, bias, axis, episilon) + + shape_np = [v[1] if isinstance(v, tuple) else v for v in shape] + scale_shape_np = [shape_np[dim] for dim in axis] + data_np = np.random.uniform(size=shape_np).astype(dtype) + weight_np = np.random.uniform(size=scale_shape_np).astype(dtype) + bias_np = 
np.random.uniform(size=scale_shape_np).astype(dtype) + b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, bias_np, axis, episilon) + + with tvm.target.Target(target): + s_func = tvm.topi.testing.dispatch(target, _rms_norm_schedule) + s = s_func([B]) + data_tvm = tvm.nd.array(data_np, dev) + weight_tvm = tvm.nd.array(weight_np, dev) + bias_tvm = tvm.nd.array(bias_np, dev) + b_tvm = tvm.nd.array(np.zeros(shape_np, dtype=dtype), dev) + f = tvm.build(s, [data, weight, bias, B], target) + f(data_tvm, weight_tvm, bias_tvm, b_tvm) + tvm.testing.assert_allclose(b_tvm.numpy(), b_np, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + tvm.testing.main()