117 changes: 117 additions & 0 deletions include/tvm/topi/nn/layer_norm.h
@@ -0,0 +1,117 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \brief layer normalization op constructions
* \file nn/layer_norm.h
*/
#ifndef TVM_TOPI_NN_LAYER_NORM_H_
#define TVM_TOPI_NN_LAYER_NORM_H_

#include <tvm/te/operation.h>
#include <tvm/topi/tags.h>

#include <string>

namespace tvm {
namespace topi {
namespace nn {

using namespace tvm::te;

/*!
* \brief Layer normalization.
* \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}]
* \param gamma K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and
* d_{axis_k} == r_k
* \param beta Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where
* d_{axis_k} == r_k
* \param axis The axis to normalize over.
* \param epsilon The epsilon value to avoid division by zero.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
* \return The normalized tensor, with the same shape as data.
*/
inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta,
const Array<Integer>& axis, double epsilon,
std::string name = "T_layer_norm", std::string tag = kInjective) {
// sum x and x^2
auto ndim = data->shape.size();
ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
auto reduce_axes = MakeReduceAxes(real_axis, data);
auto target_shape =
MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/true);
auto func = MakeTupleSumReducer();

auto compute = [ndim, &real_axis, &reduce_axes, &func, &data](const Array<Var>& indices) {
Array<PrimExpr> eval_range;
int arg_counter = 0;
int red_counter = 0;

for (size_t i = 0; i < ndim; ++i) {
if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
// real_axis contains i
eval_range.push_back(reduce_axes[red_counter]);
red_counter++;
} else {
eval_range.push_back(indices[arg_counter]);
arg_counter++;
}
}
auto square = [](const PrimExpr& x) { return x * x; };
return func({data(eval_range), square(data(eval_range))}, reduce_axes, nullptr);
};

auto temp_x_x2 =
tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduce);

auto temp_x = temp_x_x2[0];
auto temp_x2 = temp_x_x2[1];

auto reduce_extent = make_const(data->dtype, 1);
for (int i : real_axis) {
reduce_extent *= data->shape[i];
}
auto layer_norm_func = [&](const Array<Var>& indices) {
Array<Var> reduce_indices, non_reduce_indices;
for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
reduce_indices.push_back(indices[i]);
} else {
non_reduce_indices.push_back(indices[i]);
}
}
auto mean = temp_x(non_reduce_indices) / reduce_extent;
auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean;
auto layer_norm = (data(indices) - mean) * tvm::rsqrt(var + make_const(var->dtype, epsilon));
layer_norm = topi::multiply(layer_norm, gamma(reduce_indices));
if (beta.defined()) {
layer_norm = topi::add(layer_norm, beta(reduce_indices));
}
return layer_norm;
};
return tvm::te::compute(data->shape, layer_norm_func, name, tag);
}

} // namespace nn
} // namespace topi
} // namespace tvm

#endif // TVM_TOPI_NN_LAYER_NORM_H_
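
For reference, the kernel above implements the usual layer-normalization arithmetic over the reduction axes; the symbols below are notation only, not identifiers from the patch. With reduction extent R = prod_k d_{axis_k}, the fused (sum x, sum x^2) reduction stored in temp_x_x2 gives

    \mu = \frac{1}{R}\sum x, \qquad \sigma^2 = \frac{1}{R}\sum x^2 - \mu^2, \qquad
    \mathrm{out} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma \; (+\, \beta \text{ when beta is defined}),

which matches the mean, var, and layer_norm expressions in layer_norm_func.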
23 changes: 23 additions & 0 deletions include/tvm/topi/reduction.h
@@ -570,6 +570,29 @@ inline Tensor prod(const Tensor& data, const Array<Integer>& axis, bool keepdims
return CommReduce(data, axis, ProdOp, keepdims, atleast1d);
}

/*!
* \brief Create commutative reducer summing over tuples
*/
inline FCommReduce MakeTupleSumReducer() {
auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
Array<PrimExpr> result;
ICHECK_EQ(lhs.size(), rhs.size());
result.reserve(lhs.size());
for (size_t i = 0; i < lhs.size(); ++i) {
result.push_back(lhs[i] + rhs[i]);
}
return result;
};
auto fidentity = [](std::vector<DataType> types) {
Array<PrimExpr> result;
for (size_t i = 0; i < types.size(); ++i) {
result.push_back(tvm::tir::make_const(types[i], 0));
}
return result;
};
return MakeCommReducer(fcombine, fidentity, "tuple_sum");
}

} // namespace topi
} // namespace tvm
#endif // TVM_TOPI_REDUCTION_H_
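
A minimal NumPy sketch of the arithmetic this tuple-sum reducer enables (illustrative names, not code from the patch): a reduction that carries the pair (sum x, sum x^2), from which mean and variance are recovered exactly as layer_norm.h does.

    import numpy as np

    def sum_and_sumsq(data, axis):
        # The reducer combines pairs element-wise: (s0 + s1, q0 + q1), identity (0, 0).
        return np.sum(data, axis=axis), np.sum(np.square(data), axis=axis)

    x = np.random.rand(4, 16).astype("float32")
    s, q = sum_and_sumsq(x, axis=1)
    n = x.shape[1]
    mean = s / n
    var = q / n - mean * mean  # same variance formula as in layer_norm.h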
1 change: 1 addition & 0 deletions python/tvm/topi/nn/__init__.py
@@ -38,6 +38,7 @@
from .bnn import *
from .qnn import *
from .upsampling import *
from .layer_norm import layer_norm
Review thread on this line:

Contributor: What about importing * 👀? Since I see all other imports import *.
    Suggested change:
    -from .layer_norm import layer_norm
    +from .layer_norm import *

Member: Wildcard importing is actually not a good idea though lol

Member Author: agreed, so I avoid using wildcard here. perhaps we should clean up this file in the future
from .local_response_norm import *
from .bitserial_conv2d import *
from .bitserial_dense import *
46 changes: 46 additions & 0 deletions python/tvm/topi/nn/layer_norm.py
@@ -0,0 +1,46 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Layer normalization operator."""
from .. import cpp


def layer_norm(data, gamma, beta, axis, epsilon=1e-5):
"""Layer normalization operator.

Parameters
----------
data : tvm.te.Tensor
N-D with shape (d_0, d_1, ..., d_{N-1})

gamma: tvm.te.Tensor
K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k

beta: tvm.te.Tensor
Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k

axis : list of int
The axes over which the normalization is applied

epsilon : float
The epsilon value to avoid division by zero.

Returns
-------
result : tvm.te.Tensor
N-D with shape (d_0, d_1, ..., d_{N-1})
"""
return cpp.nn.layer_norm(data, gamma, beta, axis, epsilon)
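
A minimal usage sketch of this wrapper once the patch is built; shapes and names are illustrative, and lowering assumes the generic injective schedule used in the test file below.

    from tvm import te, topi

    data = te.placeholder((4, 16), name="data", dtype="float32")
    gamma = te.placeholder((16,), name="gamma", dtype="float32")
    beta = te.placeholder((16,), name="beta", dtype="float32")
    out = topi.nn.layer_norm(data, gamma, beta, axis=(1,), epsilon=1e-5)
    # out has the same shape as data and can be scheduled and built like any
    # injective op, e.g. with topi.generic.schedule_injective under an llvm target.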
1 change: 1 addition & 0 deletions python/tvm/topi/testing/__init__.py
@@ -43,6 +43,7 @@
from .reorg_python import reorg_python
from .roi_align_python import roi_align_nchw_python, roi_align_nhwc_python
from .roi_pool_python import roi_pool_nchw_python
from .layer_norm_python import layer_norm_python
from .lrn_python import lrn_python
from .l2_normalize_python import l2_normalize_python
from .gather_python import gather_python
53 changes: 53 additions & 0 deletions python/tvm/topi/testing/layer_norm_python.py
@@ -0,0 +1,53 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""Layer normalization in python"""
import numpy as np


def layer_norm_python(data, gamma, beta, axis, epsilon=1e-5):
"""Layer normalization operator in Python.

Parameters
----------
data : numpy.ndarray
N-D with shape (d_0, d_1, ..., d_{N-1})

gamma: numpy.ndarray
K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k

beta: numpy.ndarray
Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k

axis : int or tuple of ints
The axes over which the normalization is applied

epsilon : float
The epsilon value to avoid division by zero.

Returns
-------
result : np.ndarray
N-D with shape (d_0, d_1, ..., d_{N-1})
"""
mean = np.mean(data, axis, keepdims=True)
var = np.var(data, axis, keepdims=True)
result = (data - mean) / np.sqrt(var + epsilon)
result *= gamma
if beta is not None:
result += beta
return result
9 changes: 9 additions & 0 deletions src/tir/schedule/primitive/reduction.cc
@@ -330,6 +330,15 @@ struct ReducerRegistry {
[](const Array<PrimExpr>& values) {
return Array<PrimExpr>{min_value(values[0]->dtype)};
}),
CreateReducerGetter(
/*n_buffers=*/2,
[](const Array<Var>& x, const Array<Var>& y) {
return Array<PrimExpr>{x[0] + y[0], x[1] + y[1]};
},
[](const Array<PrimExpr>& values) {
return Array<PrimExpr>{make_const(values[0]->dtype, 0),
make_const(values[1]->dtype, 0)};
}),
CreateReducerGetter(
/*n_buffers=*/2,
[](const Array<Var>& x, const Array<Var>& y) {
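
This registry entry makes the two-buffer (sum x, sum x^2) reducer recognizable to TIR scheduling primitives such as rfactor and cross-thread reduction. A hedged Python rendering of its semantics, not code from the patch:

    def combine(x, y):
        # x and y are (sum, sum_sq) pairs held in the two reduction buffers
        return (x[0] + y[0], x[1] + y[1])

    def identity(dtype0, dtype1):
        # zero of each buffer's dtype, matching make_const(..., 0) above
        return (0, 0)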
6 changes: 6 additions & 0 deletions src/topi/nn.cc
@@ -29,6 +29,7 @@
#include <tvm/topi/nn/dense.h>
#include <tvm/topi/nn/dilate.h>
#include <tvm/topi/nn/flatten.h>
#include <tvm/topi/nn/layer_norm.h>
#include <tvm/topi/nn/local_response_norm.h>
#include <tvm/topi/nn/mapping.h>
#include <tvm/topi/nn/pooling.h>
@@ -157,5 +158,10 @@ TVM_REGISTER_GLOBAL("topi.nn.binary_dense").set_body([](TVMArgs args, TVMRetValu
*rv = nn::binary_dense(args[0], args[1]);
});

/* Ops from nn/layer_norm.h */
TVM_REGISTER_GLOBAL("topi.nn.layer_norm").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::layer_norm(args[0], args[1], args[2], args[3], static_cast<double>(args[4]));
});

} // namespace topi
} // namespace tvm
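
The TVM_REGISTER_GLOBAL entry is what the Python wrapper's cpp.nn.layer_norm call resolves to. A small FFI sketch, assuming the patch is compiled into the TVM runtime (names illustrative):

    import tvm

    f = tvm.get_global_func("topi.nn.layer_norm")
    # f is the PackedFunc registered above; tvm.topi.cpp.nn.layer_norm reaches the
    # same entry point through the topi FFI namespace.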
62 changes: 62 additions & 0 deletions tests/python/topi/python/test_topi_layer_norm.py
@@ -0,0 +1,62 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for layer_norm."""
import numpy as np
import pytest
import tvm
from tvm import te
from tvm import topi
from tvm.topi.utils import get_const_tuple
import tvm.topi.testing

import tvm.testing


_layer_norm_schedule = {
"generic": topi.generic.schedule_injective,
}


# Only test on llvm because schedules for other targets are missing.
@tvm.testing.parametrize_targets("llvm")
@pytest.mark.parametrize("shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2))])
def test_layer_norm(target, dev, shape, axis, epsilon=1e-5, dtype="float32", rtol=1e-5, atol=1e-5):
data = te.placeholder(shape, dtype=dtype, name="data")
scale_shape = [shape[dim] for dim in axis]
gamma = te.placeholder(scale_shape, dtype=dtype, name="gamma")
beta = te.placeholder(scale_shape, dtype=dtype, name="beta")
B = topi.nn.layer_norm(data, gamma, beta, axis, epsilon)

data_np = np.random.uniform(size=shape).astype(dtype)
gamma_np = np.random.uniform(size=scale_shape).astype(dtype)
beta_np = np.random.uniform(size=scale_shape).astype(dtype)
b_np = tvm.topi.testing.layer_norm_python(data_np, gamma_np, beta_np, axis, epsilon)

with tvm.target.Target(target):
s_func = tvm.topi.testing.dispatch(target, _layer_norm_schedule)
s = s_func([B])
data_tvm = tvm.nd.array(data_np, dev)
gamma_tvm = tvm.nd.array(gamma_np, dev)
beta_tvm = tvm.nd.array(beta_np, dev)
b_tvm = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), dev)
f = tvm.build(s, [data, gamma, beta, B], target)
f(data_tvm, gamma_tvm, beta_tvm, b_tvm)
tvm.testing.assert_allclose(b_tvm.asnumpy(), b_np, rtol=rtol, atol=atol)


if __name__ == "__main__":
tvm.testing.main()