-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[TOPI] Add layer norm operator #12864
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,117 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * "License"); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, | ||
| * software distributed under the License is distributed on an | ||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| * KIND, either express or implied. See the License for the | ||
| * specific language governing permissions and limitations | ||
| * under the License. | ||
| */ | ||
|
|
||
| /*! | ||
| * \brief layer normalization op constructions | ||
| * \file nn/layer_norm.h | ||
| */ | ||
| #ifndef TVM_TOPI_NN_LAYER_NORM_H_ | ||
| #define TVM_TOPI_NN_LAYER_NORM_H_ | ||
|
|
||
| #include <tvm/te/operation.h> | ||
| #include <tvm/topi/tags.h> | ||
|
|
||
| #include <string> | ||
|
|
||
| namespace tvm { | ||
| namespace topi { | ||
| namespace nn { | ||
|
|
||
| using namespace tvm::te; | ||
|
|
||
| /*! | ||
| * \brief Layer normalization. | ||
| * \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}] | ||
| * \param gamma K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and | ||
| * d_{axis_k} == r_k | ||
| * \param beta Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where | ||
| * d_{axis_k} == r_k | ||
| * \param axis The axis to normalize over. | ||
| * \param epsilon The epsilon value to avoid division by zero. | ||
| * \param name The name of the operation. | ||
| * \param tag The tag to mark the operation. | ||
| * \return The normalized tensor, with the same shape as data. | ||
| */ | ||
| inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta, | ||
| const Array<Integer>& axis, double epsilon, | ||
| std::string name = "T_layer_norm", std::string tag = kInjective) { | ||
| // sum x and x^2 | ||
| auto ndim = data->shape.size(); | ||
| ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; | ||
| auto real_axis = GetRealAxis(static_cast<int>(ndim), axis); | ||
| auto reduce_axes = MakeReduceAxes(real_axis, data); | ||
| auto target_shape = | ||
| MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/true); | ||
| auto func = MakeTupleSumReducer(); | ||
|
|
||
| auto compute = [ndim, &real_axis, &reduce_axes, &func, &data](const Array<Var>& indices) { | ||
| Array<PrimExpr> eval_range; | ||
| int arg_counter = 0; | ||
| int red_counter = 0; | ||
|
|
||
| for (size_t i = 0; i < ndim; ++i) { | ||
| if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) { | ||
| // real_axis contains i | ||
| eval_range.push_back(reduce_axes[red_counter]); | ||
| red_counter++; | ||
| } else { | ||
| eval_range.push_back(indices[arg_counter]); | ||
| arg_counter++; | ||
| } | ||
| } | ||
| auto square = [](const PrimExpr& x) { return x * x; }; | ||
| return func({data(eval_range), square(data(eval_range))}, reduce_axes, nullptr); | ||
| }; | ||
|
|
||
| auto temp_x_x2 = | ||
| tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduce); | ||
|
|
||
| auto temp_x = temp_x_x2[0]; | ||
| auto temp_x2 = temp_x_x2[1]; | ||
|
|
||
| auto reduce_extent = make_const(data->dtype, 1); | ||
| for (int i : real_axis) { | ||
| reduce_extent *= data->shape[i]; | ||
| } | ||
| auto layer_norm_func = [&](const Array<Var>& indices) { | ||
| Array<Var> reduce_indices, non_reduce_indices; | ||
| for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) { | ||
| if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) { | ||
| reduce_indices.push_back(indices[i]); | ||
| } else { | ||
| non_reduce_indices.push_back(indices[i]); | ||
| } | ||
| } | ||
| auto mean = temp_x(non_reduce_indices) / reduce_extent; | ||
| auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean; | ||
| auto layer_norm = (data(indices) - mean) * tvm::rsqrt(var + make_const(var->dtype, epsilon)); | ||
| layer_norm = topi::multiply(layer_norm, gamma(reduce_indices)); | ||
| if (beta.defined()) { | ||
| layer_norm = topi::add(layer_norm, beta(reduce_indices)); | ||
| } | ||
| return layer_norm; | ||
| }; | ||
| return tvm::te::compute(data->shape, layer_norm_func, name, tag); | ||
| } | ||
|
|
||
| } // namespace nn | ||
| } // namespace topi | ||
| } // namespace tvm | ||
|
|
||
| #endif // TVM_TOPI_NN_LAYER_NORM_H_ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
| """Layer normalization operator.""" | ||
| from .. import cpp | ||
|
|
||
|
|
||
| def layer_norm(data, gamma, beta, axis, epsilon=1e-5): | ||
| """Layer normalization operator. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| data : tvm.te.Tensor | ||
| N-D with shape (d_0, d_1, ..., d_{N-1}) | ||
|
|
||
| gamma: tvm.te.Tensor | ||
| K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k | ||
|
|
||
| beta: tvm.te.Tensor | ||
| Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k | ||
|
|
||
| axis : list of int | ||
| Axis over the normalization applied | ||
|
|
||
| epsilon : float | ||
| The epsilon value to avoid division by zero. | ||
|
|
||
| Returns | ||
| ------- | ||
| result : tvm.te.Tensor | ||
| N-D with shape (d_0, d_1, ..., d_{N-1}) | ||
| """ | ||
| return cpp.nn.layer_norm(data, gamma, beta, axis, epsilon) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
| # pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals | ||
| """Layer normalization in python""" | ||
| import numpy as np | ||
|
|
||
|
|
||
| def layer_norm_python(data, gamma, beta, axis, epsilon=1e-5): | ||
| """Layer normalization operator in Python. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| data : numpy.ndarray | ||
| N-D with shape (d_0, d_1, ..., d_{N-1}) | ||
|
|
||
| gamma: numpy.ndarray | ||
| K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k | ||
|
|
||
| beta: numpy.ndarray | ||
| Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k | ||
|
|
||
| axis : int or tuple of ints | ||
| Axis over the normalization applied | ||
|
|
||
| epsilon : float | ||
| The epsilon value to avoid division by zero. | ||
|
|
||
| Returns | ||
| ------- | ||
| result : np.ndarray | ||
| N-D with shape (d_0, d_1, ..., d_{N-1}) | ||
| """ | ||
| mean = np.mean(data, axis, keepdims=True) | ||
| var = np.var(data, axis, keepdims=True) | ||
| result = (data - mean) / np.sqrt(var + epsilon) | ||
| result *= gamma | ||
| if beta is not None: | ||
| result += beta | ||
| return result |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
| """Test code for layer_norm.""" | ||
| import numpy as np | ||
| import pytest | ||
| import tvm | ||
| from tvm import te | ||
| from tvm import topi | ||
| from tvm.topi.utils import get_const_tuple | ||
| import tvm.topi.testing | ||
|
|
||
| import tvm.testing | ||
|
|
||
|
|
||
| _layer_norm_schedule = { | ||
| "generic": topi.generic.schedule_injective, | ||
| } | ||
|
|
||
|
|
||
| # only test on llvm because schedule is missing | ||
| @tvm.testing.parametrize_targets("llvm") | ||
| @pytest.mark.parametrize("shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2))]) | ||
| def test_layer_norm(target, dev, shape, axis, episilon=1e-5, dtype="float32", rtol=1e-5, atol=1e-5): | ||
| data = te.placeholder(shape, dtype=dtype, name="data") | ||
| scale_shape = [shape[dim] for dim in axis] | ||
| gamma = te.placeholder(scale_shape, dtype=dtype, name="gamma") | ||
| beta = te.placeholder(scale_shape, dtype=dtype, name="beta") | ||
| B = topi.nn.layer_norm(data, gamma, beta, axis, episilon) | ||
|
|
||
| data_np = np.random.uniform(size=shape).astype(dtype) | ||
| gamma_np = np.random.uniform(size=scale_shape).astype(dtype) | ||
| beta_np = np.random.uniform(size=scale_shape).astype(dtype) | ||
| b_np = tvm.topi.testing.layer_norm_python(data_np, gamma_np, beta_np, axis, episilon) | ||
|
|
||
| with tvm.target.Target(target): | ||
| s_func = tvm.topi.testing.dispatch(target, _layer_norm_schedule) | ||
| s = s_func([B]) | ||
| data_tvm = tvm.nd.array(data_np, dev) | ||
| gamma_tvm = tvm.nd.array(gamma_np, dev) | ||
| beta_tvm = tvm.nd.array(beta_np, dev) | ||
| b_tvm = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), dev) | ||
| f = tvm.build(s, [data, gamma, beta, B], target) | ||
| f(data_tvm, gamma_tvm, beta_tvm, b_tvm) | ||
| tvm.testing.assert_allclose(b_tvm.asnumpy(), b_np, rtol=rtol, atol=atol) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| tvm.testing.main() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about importing
*👀? Since I see all other imports import*.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wildcard importing is actually not a good idea though lol
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
agreed, so I avoid using wildcard here. perhaps we should clean up this file in the future