11 changes: 11 additions & 0 deletions include/tvm/relax/attrs/nn.h
@@ -287,6 +287,17 @@ struct GroupNormAttrs : public tvm::AttrsNode<GroupNormAttrs> {
}
}; // struct GroupNormAttrs

/*! \brief Attributes used in rms_norm operator */
struct RMSNormAttrs : public tvm::AttrsNode<RMSNormAttrs> {
Array<Integer> axes;
double epsilon;

TVM_DECLARE_ATTRS(RMSNormAttrs, "relax.attrs.RMSNormAttrs") {
TVM_ATTR_FIELD(axes).describe("The axes along which the normalization is applied.");
TVM_ATTR_FIELD(epsilon).describe("Small float added to the mean square to avoid dividing by zero.");
}
}; // struct RMSNormAttrs

/*! \brief Attributes used in nll_loss operator */
struct NLLLossAttrs : public tvm::AttrsNode<NLLLossAttrs> {
String reduction;
94 changes: 94 additions & 0 deletions include/tvm/topi/nn/rms_norm.h
@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \brief root mean square normalization op constructions
* \file nn/rms_norm.h
*/
#ifndef TVM_TOPI_NN_RMS_NORM_H_
#define TVM_TOPI_NN_RMS_NORM_H_

#include <tvm/te/operation.h>
#include <tvm/topi/reduction.h>
#include <tvm/topi/tags.h>

#include <string>

namespace tvm {
namespace topi {
namespace nn {

using namespace tvm::te;

/*!
* \brief Root mean square normalization.
* \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}]
* \param weight K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and
* d_{axis_k} == r_k
* \param axis The axis to normalize over.
* \param epsilon The epsilon value to avoid division by zero.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
* \return The normalized tensor, with the same shape as data.
*/
inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array<Integer>& axis,
double epsilon, std::string name = "T_rms_norm",
std::string tag = kInjective) {
const auto& data_type = data->dtype;
const auto& weight_type = weight.defined() ? weight->dtype : data_type;
ICHECK(data_type == weight_type) << "rms_norm: data and weight must have the same type";
ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16))
<< "rms_norm: only support float32 and float16 for now";
bool is_float16 = data_type == DataType::Float(16);

auto x = is_float16 ? cast(data, DataType::Float(32)) : data;
auto w = is_float16 ? cast(weight, DataType::Float(32)) : weight;
auto square = multiply(x, x);
auto square_sum = sum(square, axis, /*keepdims=*/false, /*atleast1d=*/true);

auto ndim = data->shape.size();
ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
auto reduce_extent = make_const(data->dtype, 1);
for (int i : real_axis) {
reduce_extent *= data->shape[i];
}
auto rms_norm_func = [&](const Array<Var>& indices) {
Array<Var> reduce_indices, non_reduce_indices;
for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
reduce_indices.push_back(indices[i]);
} else {
non_reduce_indices.push_back(indices[i]);
}
}
auto output =
x(indices) * w(reduce_indices) *
tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + make_const(data_type, epsilon));
return output;
};
auto rms_norm = tvm::te::compute(data->shape, rms_norm_func, name, tag);
return is_float16 ? cast(rms_norm, DataType::Float(16)) : rms_norm;
}

} // namespace nn
} // namespace topi
} // namespace tvm

#endif // TVM_TOPI_NN_RMS_NORM_H_
40 changes: 40 additions & 0 deletions python/tvm/relax/op/nn/nn.py
@@ -902,6 +902,46 @@ def group_norm(
)


def rms_norm(
data: Expr,
weight: Expr,
axes: Union[int, List[int]],
epsilon: float = 1e-5,
) -> Expr:
r"""
Root mean square normalization (Biao Zhang et al., 2019).
This operator takes an n-dimensional input array and normalizes
it along the given axes:

.. math::

out = \frac{data}{\sqrt{mean(data^2, axis)+\epsilon}} * weight

Parameters
----------
data : relax.Expr
Input to which rms_norm will be applied.

weight : relax.Expr
The scale factor.

axes : Union[int, List[int]]
The axes along which the normalization is applied.

epsilon : float
Small float added to the mean square to avoid dividing by zero.

Returns
-------
result : relax.Expr
The computed result.
"""
if isinstance(axes, int):
axes = [axes]
return _ffi_api.rms_norm(data, weight, axes, epsilon) # type: ignore
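As a quick illustration of the new Python API (not part of this diff), here is a minimal sketch that builds a small Relax function calling `rms_norm` through the `BlockBuilder`; the shapes, variable names, and the trailing `show()` call are illustrative only.

```python
# Hedged sketch: constructing a Relax function that uses relax.op.nn.rms_norm.
# Shapes and names are illustrative, not taken from the PR.
from tvm import relax

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((2, 4, 8), "float32"))
w = relax.Var("w", relax.TensorStructInfo((8,), "float32"))
with bb.function("main", [x, w]):
    # Normalize over the last axis; the weight shape must match the reduced axes.
    y = bb.emit(relax.op.nn.rms_norm(x, w, axes=-1))
    bb.emit_func_output(y)
mod = bb.get()
mod.show()  # the call carries RMSNormAttrs with axes=[-1] and epsilon=1e-05
```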


def dropout(data: Expr, rate: float = 0.5) -> Expr:
"""Applies the dropout operation to the input tensor.

11 changes: 11 additions & 0 deletions python/tvm/relax/transform/legalize_ops/nn.py
@@ -329,6 +329,17 @@ def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr:
)


@register_legalize("relax.nn.rms_norm")
def _nn_rms_norm(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
topi.nn.rms_norm,
call.args[0],
call.args[1],
axis=call.attrs.axes,
epsilon=call.attrs.epsilon,
)
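The hook above lowers the high-level call through `call_te` into the TOPI compute. A hedged sketch of triggering it, assuming `mod` is an IRModule that contains a `relax.nn.rms_norm` call (for instance the one built in the sketch above):

```python
# Hedged sketch: legalizing a module that contains relax.nn.rms_norm.
from tvm import relax

legalized = relax.transform.LegalizeOps()(mod)
# After the pass, the rms_norm call is replaced by a call_tir into a PrimFunc
# generated from topi.nn.rms_norm.
legalized.show()
```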


@register_legalize("relax.nn.dropout")
def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr:
logging.info("Dropout is handled by frontend translator at this moment and is not legalized.")
1 change: 1 addition & 0 deletions python/tvm/topi/nn/__init__.py
@@ -41,6 +41,7 @@
from .instance_norm import instance_norm
from .layer_norm import layer_norm
from .group_norm import group_norm
from .rms_norm import rms_norm
from .local_response_norm import *
from .bitserial_conv2d import *
from .bitserial_dense import *
45 changes: 45 additions & 0 deletions python/tvm/topi/nn/rms_norm.py
@@ -0,0 +1,45 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Root mean square normalization operator."""
from .. import cpp


def rms_norm(data, weight, axis, epsilon=1e-5):
"""Root mean square normalization operator.
It accepts fp16 and fp32 as input data type. It will cast the input to fp32
to perform the computation. The output will have the same data type as input.

Parameters
----------
data : tvm.te.Tensor
N-D with shape (d_0, d_1, ..., d_{N-1})

weight: tvm.te.Tensor
K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k

axis : list of int
The axes along which the normalization is applied.

epsilon : float
The epsilon value to avoid division by zero.

Returns
-------
result : tvm.te.Tensor
N-D with shape (d_0, d_1, ..., d_{N-1})
"""
return cpp.nn.rms_norm(data, weight, axis, epsilon)
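For reference, a minimal sketch (not in the PR) of driving this wrapper directly at the TE level and compiling it for CPU; the shapes and the `llvm` target are illustrative.

```python
# Hedged sketch: calling topi.nn.rms_norm on TE placeholders and building for CPU.
import numpy as np
import tvm
from tvm import te, topi

data = te.placeholder((4, 16), name="data", dtype="float32")
weight = te.placeholder((16,), name="weight", dtype="float32")
out = topi.nn.rms_norm(data, weight, axis=[1])  # normalize over the last axis

s = te.create_schedule(out.op)
f = tvm.build(s, [data, weight, out], target="llvm")

dev = tvm.cpu()
d = tvm.nd.array(np.random.uniform(size=(4, 16)).astype("float32"), dev)
w = tvm.nd.array(np.random.uniform(size=(16,)).astype("float32"), dev)
o = tvm.nd.empty((4, 16), "float32", dev)
f(d, w, o)
```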
1 change: 1 addition & 0 deletions python/tvm/topi/testing/__init__.py
@@ -46,6 +46,7 @@
from .instance_norm_python import instance_norm_python
from .layer_norm_python import layer_norm_python
from .group_norm_python import group_norm_python
from .rms_norm_python import rms_norm_python
from .lrn_python import lrn_python
from .l2_normalize_python import l2_normalize_python
from .gather_python import gather_python
48 changes: 48 additions & 0 deletions python/tvm/topi/testing/rms_norm_python.py
@@ -0,0 +1,48 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""Root mean square normalization in python"""
import numpy as np


def rms_norm_python(data, weight, axis, epsilon=1e-5):
"""Root mean square normalization operator in Python.

Parameters
----------
data : numpy.ndarray
N-D with shape (d_0, d_1, ..., d_{N-1})

weight: numpy.ndarray
K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k

axis : int or tuple of ints
The axes along which the normalization is applied.

epsilon : float
The epsilon value to avoid division by zero.

Returns
-------
result : np.ndarray
N-D with shape (d_0, d_1, ..., d_{N-1})
"""
old_dtype = data.dtype
data = data.astype("float32")
square_mean = np.mean(np.square(data), axis, keepdims=True)
result = data * weight / np.sqrt(square_mean + epsilon)
return result.astype(old_dtype)
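A hedged follow-up (not in the PR) tying this reference to the compiled kernel from the TE sketch above; `d`, `w`, and `o` are assumed to be the NDArrays produced there.

```python
# Hedged sketch: checking the compiled TOPI result against the NumPy reference.
import tvm.testing
from tvm.topi.testing import rms_norm_python

ref = rms_norm_python(d.numpy(), w.numpy(), axis=(1,))
tvm.testing.assert_allclose(o.numpy(), ref, rtol=1e-5, atol=1e-5)
```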
59 changes: 59 additions & 0 deletions src/relax/op/nn/nn.cc
@@ -417,6 +417,65 @@ TVM_REGISTER_OP("relax.nn.group_norm")
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.nn.rms_norm */
TVM_REGISTER_NODE_TYPE(RMSNormAttrs);

Expr rms_norm(Expr data, Expr weight, Array<Integer> axes, double epsilon) {
ObjectPtr<RMSNormAttrs> attrs = make_object<RMSNormAttrs>();
attrs->axes = std::move(axes);
attrs->epsilon = epsilon;

static const Op& op = Op::Get("relax.nn.rms_norm");
return Call(op, {std::move(data), std::move(weight)}, Attrs{attrs}, {});
}

TVM_REGISTER_GLOBAL("relax.op.nn.rms_norm").set_body_typed(rms_norm);

StructInfo InferStructInfoRMSNorm(const Call& call, const BlockBuilder& ctx) {
Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);

const auto* attrs = call->attrs.as<RMSNormAttrs>();
bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, attrs->axes);

return unknown_shape ? TensorStructInfo(input_sinfo[0]->dtype, input_sinfo[0]->ndim)
: input_sinfo[0];
}

InferLayoutOutput InferLayoutRMSNorm(const Call& call,
const Map<String, Array<String>>& desired_layouts,
const VarLayoutMap& var_layout_map) {
ICHECK(NoDesiredLayout(call, desired_layouts));
std::vector<NLayout> initial_layouts;
for (size_t i = 0; i < 2; ++i) {
const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[i]);
ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim";
initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim));
}
const auto* attrs = call->attrs.as<RMSNormAttrs>();
ICHECK(attrs) << "Invalid Call";

LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
ObjectPtr<RMSNormAttrs> new_attrs = make_object<RMSNormAttrs>(*attrs);
std::vector<Integer> new_axis;
for (const auto& axis : attrs->axes) {
new_axis.push_back(FindAxis(layout->layout, axis->value));
}
new_attrs->axes = std::move(new_axis);
return InferLayoutOutput({layout, initial_layouts[1]}, {layout}, Attrs(new_attrs));
}

TVM_REGISTER_OP("relax.nn.rms_norm")
.set_attrs_type<RMSNormAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "Input to which batch_norm will be applied.")
.add_argument("weight", "Tensor", "Input to which batch_norm will be applied.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoRMSNorm)
.set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutRMSNorm)
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.nn.dropout */
TVM_REGISTER_NODE_TYPE(DropoutAttrs);

3 changes: 3 additions & 0 deletions src/relax/op/nn/nn.h
@@ -75,6 +75,9 @@ Expr layer_norm(Expr data, Expr gamma, Expr beta, Array<Integer> axes, double ep
Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis,
Array<Integer> axes, double epsilon, bool center, bool scale);

/*! \brief Compute root mean square normalization. */
Expr rms_norm(Expr data, Expr weight, Array<Integer> axes, double epsilon);

/*!
* \brief Applies the dropout operation to the input tensor.
* \param data The input data to the operator.
6 changes: 6 additions & 0 deletions src/topi/nn.cc
@@ -35,6 +35,7 @@
#include <tvm/topi/nn/local_response_norm.h>
#include <tvm/topi/nn/mapping.h>
#include <tvm/topi/nn/pooling.h>
#include <tvm/topi/nn/rms_norm.h>
#include <tvm/topi/nn/softmax.h>

namespace tvm {
@@ -176,5 +177,10 @@ TVM_REGISTER_GLOBAL("topi.nn.instance_norm").set_body([](TVMArgs args, TVMRetVal
*rv = nn::instance_norm(args[0], args[1], args[2], args[3], static_cast<double>(args[4]));
});

/* Ops from nn/rms_norm.h */
TVM_REGISTER_GLOBAL("topi.nn.rms_norm").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::rms_norm(args[0], args[1], args[2], static_cast<double>(args[3]));
});

} // namespace topi
} // namespace tvm