From 679edc64973b4fbe35c5c6c853dfb9ae09f37120 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 21 Sep 2022 10:04:35 -0700 Subject: [PATCH 1/4] [TOPI] Add one-pass layer norm using tuple reduction --- include/tvm/topi/nn/layer_norm.h | 117 ++++++++++++++++++ include/tvm/topi/reduction.h | 23 ++++ python/tvm/topi/nn/__init__.py | 1 + python/tvm/topi/nn/layer_norm.py | 47 +++++++ python/tvm/topi/testing/__init__.py | 1 + python/tvm/topi/testing/layer_norm_python.py | 53 ++++++++ src/topi/nn.cc | 6 + .../topi/python/test_topi_layer_norm.py | 63 ++++++++++ 8 files changed, 311 insertions(+) create mode 100644 include/tvm/topi/nn/layer_norm.h create mode 100644 python/tvm/topi/nn/layer_norm.py create mode 100644 python/tvm/topi/testing/layer_norm_python.py create mode 100644 tests/python/topi/python/test_topi_layer_norm.py diff --git a/include/tvm/topi/nn/layer_norm.h b/include/tvm/topi/nn/layer_norm.h new file mode 100644 index 000000000000..f7b82733887c --- /dev/null +++ b/include/tvm/topi/nn/layer_norm.h @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \brief layer normalization op constructions + * \file nn/layer_norm.h + */ +#ifndef TVM_TOPI_NN_LAYER_NORM_H_ +#define TVM_TOPI_NN_LAYER_NORM_H_ + +#include +#include + +#include + +namespace tvm { +namespace topi { +namespace nn { + +using namespace tvm::te; + +/*! + * \brief Layer normalization. + * \param data N-D tensor with shape [d_0, d_1, ..., d_n] + * \param gamma R-D tensor with shape [r_0, r_1, ..., r_k] where R == len(axis) and d_{axis_i} == + * r_i + * \param beta Optional, R-D tensor with shape [r_0, r_1, ..., r_k] where R == len(axis) and + * d_{axis_i} == r_i + * \param axis The axis to normalize over. + * \param epsilon The epsilon value to avoid division by zero. + * \param name The name of the operation. + * \param tag The tag to mark the operation. + * \return The normalized tensor, with the same shape as data. 
+ */
+inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta,
+                         const Array<Integer>& axis, double epsilon,
+                         std::string name = "T_layer_norm", std::string tag = kInjective) {
+  // sum x and x^2
+  auto ndim = data->shape.size();
+  ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
+  auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
+  auto reduce_axes = MakeReduceAxes(real_axis, data);
+  auto target_shape =
+      MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/true);
+  auto func = MakeTupleSumReducer();
+
+  auto compute = [ndim, &real_axis, &reduce_axes, &func, &data](const Array<Var>& indices) {
+    Array<PrimExpr> eval_range;
+    int arg_counter = 0;
+    int red_counter = 0;
+
+    for (size_t i = 0; i < ndim; ++i) {
+      if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
+        // real_axis contains i
+        eval_range.push_back(reduce_axes[red_counter]);
+        red_counter++;
+      } else {
+        eval_range.push_back(indices[arg_counter]);
+        arg_counter++;
+      }
+    }
+    auto square = [](const PrimExpr& x) { return x * x; };
+    return func({data(eval_range), square(data(eval_range))}, reduce_axes, nullptr);
+  };
+
+  auto temp_x_x2 =
+      tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduce);
+
+  auto temp_x = temp_x_x2[0];
+  auto temp_x2 = temp_x_x2[1];
+
+  auto reduce_extent = make_const(data->dtype, 1);
+  for (int i : real_axis) {
+    reduce_extent *= data->shape[i];
+  }
+  auto layer_norm_func = [&](const Array<Var>& indices) {
+    Array<Var> reduce_indices, non_reduce_indices;
+    for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
+      if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
+        reduce_indices.push_back(indices[i]);
+      } else {
+        non_reduce_indices.push_back(indices[i]);
+      }
+    }
+    auto mean = temp_x(non_reduce_indices) / reduce_extent;
+    auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean;
+    auto layer_norm = (data(indices) - mean) * tvm::rsqrt(var + make_const(var->dtype, epsilon));
+    layer_norm = topi::multiply(layer_norm, gamma(reduce_indices));
+    if (beta.defined()) {
+      layer_norm = topi::add(layer_norm, beta(reduce_indices));
+    }
+    return layer_norm;
+  };
+  return tvm::te::compute(data->shape, layer_norm_func, name, tag);
+};
+
+}  // namespace nn
+}  // namespace topi
+}  // namespace tvm
+
+#endif  // TVM_TOPI_NN_LAYER_NORM_H_
\ No newline at end of file
diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h
index d4e420d80b02..5e79bd429d6f 100644
--- a/include/tvm/topi/reduction.h
+++ b/include/tvm/topi/reduction.h
@@ -570,6 +570,29 @@ inline Tensor prod(const Tensor& data, const Array<Integer>& axis, bool keepdims
   return CommReduce(data, axis, ProdOp, keepdims, atleast1d);
 }
 
+/*!
+ * \brief Create commutative reducer summing over tuples
+ */
+inline FCommReduce MakeTupleSumReducer() {
+  auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
+    Array<PrimExpr> result;
+    ICHECK_EQ(lhs.size(), rhs.size());
+    result.reserve(lhs.size());
+    for (size_t i = 0; i < lhs.size(); ++i) {
+      result.push_back(lhs[i] + rhs[i]);
+    }
+    return result;
+  };
+  auto fidentity = [](std::vector<DataType> types) {
+    Array<PrimExpr> result;
+    for (size_t i = 0; i < types.size(); ++i) {
+      result.push_back(tvm::tir::make_const(types[i], 0));
+    }
+    return result;
+  };
+  return MakeCommReducer(fcombine, fidentity, "tuple_sum");
+}
+
 }  // namespace topi
 }  // namespace tvm
 #endif  // TVM_TOPI_REDUCTION_H_
diff --git a/python/tvm/topi/nn/__init__.py b/python/tvm/topi/nn/__init__.py
index 1dd922d76819..8f081242fa10 100644
--- a/python/tvm/topi/nn/__init__.py
+++ b/python/tvm/topi/nn/__init__.py
@@ -38,6 +38,7 @@
 from .bnn import *
 from .qnn import *
 from .upsampling import *
+from .layer_norm import layer_norm
 from .local_response_norm import *
 from .bitserial_conv2d import *
 from .bitserial_dense import *
diff --git a/python/tvm/topi/nn/layer_norm.py b/python/tvm/topi/nn/layer_norm.py
new file mode 100644
index 000000000000..245fbcb47cd9
--- /dev/null
+++ b/python/tvm/topi/nn/layer_norm.py
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""TVM operator flatten compute."""
+import tvm
+from .. import tag, cpp
+
+
+def layer_norm(data, gamma, beta, axis, epsilon=1e-5):
+    """Layer normalization operator.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        N-D with shape (d_0, d_1, ..., d_n)
+
+    gamma: tvm.te.Tensor
+        R-D with shape (r_0, r_1, ..., r_k) where R == len(axis) and d_{axis_i} == r_i
+
+    beta: tvm.te.Tensor
+        Optional, R-D with shape (r_0, r_1, ..., r_k) where R == len(axis) and d_{axis_i} == r_i
+
+    axis : list of int
+        Axis over the normalization applied
+
+    epsilon : float
+        The epsilon value to avoid division by zero.
+ + Returns + ------- + result : tvm.te.Tensor + N-D with shape (d_0, d_1, ..., d_n) + """ + return cpp.nn.layer_norm(data, gamma, beta, axis, epsilon) \ No newline at end of file diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index 2f091cba10b7..2922c30b505c 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -43,6 +43,7 @@ from .reorg_python import reorg_python from .roi_align_python import roi_align_nchw_python, roi_align_nhwc_python from .roi_pool_python import roi_pool_nchw_python +from .layer_norm_python import layer_norm_python from .lrn_python import lrn_python from .l2_normalize_python import l2_normalize_python from .gather_python import gather_python diff --git a/python/tvm/topi/testing/layer_norm_python.py b/python/tvm/topi/testing/layer_norm_python.py new file mode 100644 index 000000000000..263a17ff71df --- /dev/null +++ b/python/tvm/topi/testing/layer_norm_python.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals +"""Layer normalization in python""" +import numpy as np + + +def layer_norm_python(data, gamma, beta, axis, epsilon=1e-5): + """Layer normalization operator in Python. + + Parameters + ---------- + data : numpy.ndarray + N-D with shape (d_0, d_1, ..., d_n) + + gamma: numpy.ndarray + R-D with shape (r_0, r_1, ..., r_r) where R == len(axis) and d_{axis_i} == r_i + + beta: numpy.ndarray + Optional, R-D with shape (r_0, r_1, ..., r_r) where R == len(axis) and d_{axis_i} == r_i + + axis : list of int + Axis over the normalization applied + + epsilon : float + The epsilon value to avoid division by zero. 
+ + Returns + ------- + result : np.ndarray + N-D with shape (d_0, d_1, ..., d_n) + """ + mean = np.mean(data, axis, keepdims=True) + var = np.var(data, axis, keepdims=True) + result = (data - mean) / np.sqrt(var + epsilon) + result *= gamma + if beta is not None: + result += beta + return result diff --git a/src/topi/nn.cc b/src/topi/nn.cc index 2950aee4e90d..35dbf3a03e4f 100644 --- a/src/topi/nn.cc +++ b/src/topi/nn.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -157,5 +158,10 @@ TVM_REGISTER_GLOBAL("topi.nn.binary_dense").set_body([](TVMArgs args, TVMRetValu *rv = nn::binary_dense(args[0], args[1]); }); +/* Ops from nn/layer_norm.h */ +TVM_REGISTER_GLOBAL("topi.nn.layer_norm").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = nn::layer_norm(args[0], args[1], args[2], args[3], static_cast(args[4])); +}); + } // namespace topi } // namespace tvm diff --git a/tests/python/topi/python/test_topi_layer_norm.py b/tests/python/topi/python/test_topi_layer_norm.py new file mode 100644 index 000000000000..9910e514e49d --- /dev/null +++ b/tests/python/topi/python/test_topi_layer_norm.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Test code for layer_norm.""" +import numpy as np +import pytest +import tvm +from tvm import te +from tvm import topi +from tvm.topi.utils import get_const_tuple +import tvm.topi.testing + +import tvm.testing + + +_layer_norm_schedule = { + "generic": topi.generic.schedule_injective, +} + + +# only test on llvm because schedule is missing +@tvm.testing.parametrize_targets("llvm") +@pytest.mark.parametrize("shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2))]) +def test_layer_norm(target, dev, shape, axis, episilon=1e-5, dtype="float32", rtol=1e-5, atol=1e-5): + data = te.placeholder(shape, dtype=dtype, name="data") + scale_shape = [shape[dim] for dim in axis] + gamma = te.placeholder(scale_shape, dtype=dtype, name="gamma") + beta = te.placeholder(scale_shape, dtype=dtype, name="beta") + B = topi.nn.layer_norm(data, gamma, beta, axis, episilon) + + data_np = np.random.uniform(size=shape).astype(dtype) + gamma_np = np.random.uniform(size=scale_shape).astype(dtype) + beta_np = np.random.uniform(size=scale_shape).astype(dtype) + b_np = tvm.topi.testing.layer_norm_python(data_np, gamma_np, beta_np, axis, episilon) + + print("Running on target: %s" % target) + with tvm.target.Target(target): + s_func = tvm.topi.testing.dispatch(target, _layer_norm_schedule) + s = s_func([B]) + data_tvm = tvm.nd.array(data_np, dev) + gamma_tvm = tvm.nd.array(gamma_np, dev) + beta_tvm = tvm.nd.array(beta_np, dev) + b_tvm = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), dev) + f = tvm.build(s, [data, gamma, beta, B], target) + f(data_tvm, gamma_tvm, beta_tvm, b_tvm) + tvm.testing.assert_allclose(b_tvm.asnumpy(), b_np, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + tvm.testing.main() From b2d0735189ec7f3f9882df8b7979a21476677fa9 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 21 Sep 2022 14:00:00 -0700 Subject: [PATCH 2/4] Add reducer pattern for LowerCrossThreadReduction --- python/tvm/topi/nn/layer_norm.py | 2 +- python/tvm/topi/testing/layer_norm_python.py | 2 +- src/tir/schedule/primitive/reduction.cc | 9 ++ .../topi/python/test_topi_layer_norm.py | 1 - ..._transform_lower_cross_thread_reduction.py | 149 ++++++++++++++++++ 5 files changed, 160 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/nn/layer_norm.py b/python/tvm/topi/nn/layer_norm.py index 245fbcb47cd9..f7c446be80b3 100644 --- a/python/tvm/topi/nn/layer_norm.py +++ b/python/tvm/topi/nn/layer_norm.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""TVM operator flatten compute.""" +"""Layer normalization operator.""" import tvm from .. 
import tag, cpp diff --git a/python/tvm/topi/testing/layer_norm_python.py b/python/tvm/topi/testing/layer_norm_python.py index 263a17ff71df..419fbae31af6 100644 --- a/python/tvm/topi/testing/layer_norm_python.py +++ b/python/tvm/topi/testing/layer_norm_python.py @@ -33,7 +33,7 @@ def layer_norm_python(data, gamma, beta, axis, epsilon=1e-5): beta: numpy.ndarray Optional, R-D with shape (r_0, r_1, ..., r_r) where R == len(axis) and d_{axis_i} == r_i - axis : list of int + axis : int or tuple of ints Axis over the normalization applied epsilon : float diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index dd2bcf727c40..bb43df1ce914 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -330,6 +330,15 @@ struct ReducerRegistry { [](const Array& values) { return Array{min_value(values[0]->dtype)}; }), + CreateReducerGetter( + /*n_buffers=*/2, + [](const Array& x, const Array& y) { + return Array{x[0] + y[0], x[1] + y[1]}; + }, + [](const Array& values) { + return Array{make_const(values[0]->dtype, 0), + make_const(values[1]->dtype, 0)}; + }), CreateReducerGetter( /*n_buffers=*/2, [](const Array& x, const Array& y) { diff --git a/tests/python/topi/python/test_topi_layer_norm.py b/tests/python/topi/python/test_topi_layer_norm.py index 9910e514e49d..ead05470be3b 100644 --- a/tests/python/topi/python/test_topi_layer_norm.py +++ b/tests/python/topi/python/test_topi_layer_norm.py @@ -46,7 +46,6 @@ def test_layer_norm(target, dev, shape, axis, episilon=1e-5, dtype="float32", rt beta_np = np.random.uniform(size=scale_shape).astype(dtype) b_np = tvm.topi.testing.layer_norm_python(data_np, gamma_np, beta_np, axis, episilon) - print("Running on target: %s" % target) with tvm.target.Target(target): s_func = tvm.topi.testing.dispatch(target, _layer_norm_schedule) s = s_func([B]) diff --git a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py index ff1353d2265e..8c139b710e23 100644 --- a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py +++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py @@ -1002,6 +1002,151 @@ def lowered_argmin_split_init_update_reordered( argmin_v1[i] = cross_thread_argmin_v1[0] +@T.prim_func +def layer_norm_tuple_sum( + data: T.Buffer[(128, 768), "float32"], + gamma: T.Buffer[768, "float32"], + bias: T.Buffer[768, "float32"], + T_layer_norm: T.Buffer[(128, 768), "float32"], +) -> None: + data_red_temp_v0 = T.alloc_buffer([128], dtype="float32") + data_red_temp_v1 = T.alloc_buffer([128], dtype="float32") + for i0_fused in T.thread_binding(128, thread="blockIdx.x"): + for i1_0 in T.serial(24): + for i1_1 in T.thread_binding(32, thread="threadIdx.x"): + with T.block("data_red_temp"): + ax0 = T.axis.spatial(128, i0_fused) + k1 = T.axis.reduce(768, i1_0 * 32 + i1_1) + T.reads(data[ax0, k1]) + T.writes(data_red_temp_v0[ax0], data_red_temp_v1[ax0]) + with T.init(): + data_red_temp_v0[ax0] = T.float32(0) + data_red_temp_v1[ax0] = T.float32(0) + v_data_red_temp_v0: T.float32 = data_red_temp_v0[ax0] + data[ax0, k1] + v_data_red_temp_v1: T.float32 = ( + data_red_temp_v1[ax0] + data[ax0, k1] * data[ax0, k1] + ) + data_red_temp_v0[ax0] = v_data_red_temp_v0 + data_red_temp_v1[ax0] = v_data_red_temp_v1 + for i0_i1_fused_0 in T.thread_binding(384, thread="blockIdx.x"): + for i0_i1_fused_1 in T.thread_binding(256, thread="threadIdx.x"): + with T.block("T_layer_norm"): 
+ ax0 = T.axis.spatial(128, (i0_i1_fused_0 * 256 + i0_i1_fused_1) // 768) + ax1 = T.axis.spatial(768, (i0_i1_fused_0 * 256 + i0_i1_fused_1) % 768) + T.reads( + data[ax0, ax1], + data_red_temp_v0[ax0], + data_red_temp_v1[ax0], + gamma[ax1], + bias[ax1], + ) + T.writes(T_layer_norm[ax0, ax1]) + T_layer_norm[ax0, ax1] = ( + data[ax0, ax1] - data_red_temp_v0[ax0] * T.float32(0.0013020833333333333) + ) * T.rsqrt( + data_red_temp_v1[ax0] * T.float32(0.0013020833333333333) + - data_red_temp_v0[ax0] + * T.float32(0.0013020833333333333) + * (data_red_temp_v0[ax0] * T.float32(0.0013020833333333333)) + + T.float32(1.0000000000000001e-05), + dtype="float32", + ) * gamma[ + ax1 + ] + bias[ + ax1 + ] + + +@T.prim_func +def lowered_layer_norm_tuple_sum( + data: T.Buffer[(128, 768), "float32"], + gamma: T.Buffer[768, "float32"], + bias: T.Buffer[768, "float32"], + T_layer_norm: T.Buffer[(128, 768), "float32"], +) -> None: + # with T.block("root") + data_red_temp_v0 = T.alloc_buffer([128], dtype="float32") + data_red_temp_v1 = T.alloc_buffer([128], dtype="float32") + cross_thread_data_red_temp_v0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + cross_thread_data_red_temp_v1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + in_thread_data_red_temp_v0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + in_thread_data_red_temp_v1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + for i0_fused in T.thread_binding(128, thread="blockIdx.x"): + for i1_1 in T.thread_binding(32, thread="threadIdx.x"): + with T.block("data_red_temp_in_thread_init"): + T.reads() + T.writes(in_thread_data_red_temp_v0[0], in_thread_data_red_temp_v1[0]) + in_thread_data_red_temp_v0[0] = T.float32(0) + in_thread_data_red_temp_v1[0] = T.float32(0) + for i1_0 in T.serial(24): + with T.block("data_red_temp_in_thread"): + ax0 = T.axis.spatial(128, i0_fused) + k1 = T.axis.reduce(768, i1_0 * 32 + i1_1) + T.reads(data[ax0, k1]) + T.writes(in_thread_data_red_temp_v0[0], in_thread_data_red_temp_v1[0]) + v_data_red_temp_v0: T.float32 = in_thread_data_red_temp_v0[0] + data[ax0, k1] + v_data_red_temp_v1: T.float32 = ( + in_thread_data_red_temp_v1[0] + data[ax0, k1] * data[ax0, k1] + ) + in_thread_data_red_temp_v0[0] = v_data_red_temp_v0 + in_thread_data_red_temp_v1[0] = v_data_red_temp_v1 + with T.block("data_red_temp_cross_thread"): + T.reads(in_thread_data_red_temp_v0[0], in_thread_data_red_temp_v1[0]) + T.writes(cross_thread_data_red_temp_v0[0], cross_thread_data_red_temp_v1[0]) + T.attr( + T.comm_reducer( + lambda x0, x1, y0, y1: (x0 + y0, x1 + y1), [T.float32(0), T.float32(0)] + ), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(2), + in_thread_data_red_temp_v0[0], + in_thread_data_red_temp_v1[0], + True, + cross_thread_data_red_temp_v0[0], + cross_thread_data_red_temp_v1[0], + i1_1, + dtype="handle", + ) + ) + with T.block("data_red_temp_write_back"): + ax0 = T.axis.spatial(128, i0_fused) + T.reads(cross_thread_data_red_temp_v0[0], cross_thread_data_red_temp_v1[0]) + T.writes(data_red_temp_v0[ax0], data_red_temp_v1[ax0]) + data_red_temp_v0[ax0] = cross_thread_data_red_temp_v0[0] + data_red_temp_v1[ax0] = cross_thread_data_red_temp_v1[0] + for i0_i1_fused_0 in T.thread_binding(384, thread="blockIdx.x"): + for i0_i1_fused_1 in T.thread_binding(256, thread="threadIdx.x"): + with T.block("T_layer_norm"): + ax0 = T.axis.spatial(128, (i0_i1_fused_0 * 256 + i0_i1_fused_1) // 768) + ax1 = T.axis.spatial(768, 
(i0_i1_fused_0 * 256 + i0_i1_fused_1) % 768) + T.reads( + data[ax0, ax1], + data_red_temp_v0[ax0], + data_red_temp_v1[ax0], + gamma[ax1], + bias[ax1], + ) + T.writes(T_layer_norm[ax0, ax1]) + T_layer_norm[ax0, ax1] = ( + data[ax0, ax1] - data_red_temp_v0[ax0] * T.float32(0.0013020833333333333) + ) * T.rsqrt( + data_red_temp_v1[ax0] * T.float32(0.0013020833333333333) + - data_red_temp_v0[ax0] + * T.float32(0.0013020833333333333) + * (data_red_temp_v0[ax0] * T.float32(0.0013020833333333333)) + + T.float32(1.0000000000000001e-05), + dtype="float32", + ) * gamma[ + ax1 + ] + bias[ + ax1 + ] + + # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg @@ -1087,5 +1232,9 @@ def test_lower_te(): ) # LowerCrossThreadReduction should do nothing on TE +def test_layer_norm_tuple_sum(): + _check(layer_norm_tuple_sum, lowered_layer_norm_tuple_sum) + + if __name__ == "__main__": tvm.testing.main() From d2219a7117fb012c6c6d6d59cc1219178530c90d Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 21 Sep 2022 14:20:57 -0700 Subject: [PATCH 3/4] lint --- include/tvm/topi/nn/layer_norm.h | 4 ++-- python/tvm/topi/nn/layer_norm.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/include/tvm/topi/nn/layer_norm.h b/include/tvm/topi/nn/layer_norm.h index f7b82733887c..33911207cb8a 100644 --- a/include/tvm/topi/nn/layer_norm.h +++ b/include/tvm/topi/nn/layer_norm.h @@ -108,10 +108,10 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& return layer_norm; }; return tvm::te::compute(data->shape, layer_norm_func, name, tag); -}; +} } // namespace nn } // namespace topi } // namespace tvm -#endif // TVM_TOPI_NN_LAYER_NORM_H_ \ No newline at end of file +#endif // TVM_TOPI_NN_LAYER_NORM_H_ diff --git a/python/tvm/topi/nn/layer_norm.py b/python/tvm/topi/nn/layer_norm.py index f7c446be80b3..5d252267a69d 100644 --- a/python/tvm/topi/nn/layer_norm.py +++ b/python/tvm/topi/nn/layer_norm.py @@ -15,8 +15,7 @@ # specific language governing permissions and limitations # under the License. """Layer normalization operator.""" -import tvm -from .. import tag, cpp +from .. import cpp def layer_norm(data, gamma, beta, axis, epsilon=1e-5): @@ -44,4 +43,4 @@ def layer_norm(data, gamma, beta, axis, epsilon=1e-5): result : tvm.te.Tensor N-D with shape (d_0, d_1, ..., d_n) """ - return cpp.nn.layer_norm(data, gamma, beta, axis, epsilon) \ No newline at end of file + return cpp.nn.layer_norm(data, gamma, beta, axis, epsilon) From 30a48ffc764e1bf95accecf46336afd6b255168f Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 22 Sep 2022 10:29:31 -0700 Subject: [PATCH 4/4] update docs --- include/tvm/topi/nn/layer_norm.h | 10 +++++----- python/tvm/topi/nn/layer_norm.py | 8 ++++---- python/tvm/topi/testing/layer_norm_python.py | 8 ++++---- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/include/tvm/topi/nn/layer_norm.h b/include/tvm/topi/nn/layer_norm.h index 33911207cb8a..93e5582ef184 100644 --- a/include/tvm/topi/nn/layer_norm.h +++ b/include/tvm/topi/nn/layer_norm.h @@ -37,11 +37,11 @@ using namespace tvm::te; /*! * \brief Layer normalization. 
- * \param data N-D tensor with shape [d_0, d_1, ..., d_n] - * \param gamma R-D tensor with shape [r_0, r_1, ..., r_k] where R == len(axis) and d_{axis_i} == - * r_i - * \param beta Optional, R-D tensor with shape [r_0, r_1, ..., r_k] where R == len(axis) and - * d_{axis_i} == r_i + * \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}] + * \param gamma K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and + * d_{axis_k} == r_k + * \param beta Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where + * d_{axis_k} == r_k * \param axis The axis to normalize over. * \param epsilon The epsilon value to avoid division by zero. * \param name The name of the operation. diff --git a/python/tvm/topi/nn/layer_norm.py b/python/tvm/topi/nn/layer_norm.py index 5d252267a69d..3bdeaaac61a5 100644 --- a/python/tvm/topi/nn/layer_norm.py +++ b/python/tvm/topi/nn/layer_norm.py @@ -24,13 +24,13 @@ def layer_norm(data, gamma, beta, axis, epsilon=1e-5): Parameters ---------- data : tvm.te.Tensor - N-D with shape (d_0, d_1, ..., d_n) + N-D with shape (d_0, d_1, ..., d_{N-1}) gamma: tvm.te.Tensor - R-D with shape (r_0, r_1, ..., r_k) where R == len(axis) and d_{axis_i} == r_i + K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k beta: tvm.te.Tensor - Optional, R-D with shape (r_0, r_1, ..., r_k) where R == len(axis) and d_{axis_i} == r_i + Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k axis : list of int Axis over the normalization applied @@ -41,6 +41,6 @@ def layer_norm(data, gamma, beta, axis, epsilon=1e-5): Returns ------- result : tvm.te.Tensor - N-D with shape (d_0, d_1, ..., d_n) + N-D with shape (d_0, d_1, ..., d_{N-1}) """ return cpp.nn.layer_norm(data, gamma, beta, axis, epsilon) diff --git a/python/tvm/topi/testing/layer_norm_python.py b/python/tvm/topi/testing/layer_norm_python.py index 419fbae31af6..6b3b00146983 100644 --- a/python/tvm/topi/testing/layer_norm_python.py +++ b/python/tvm/topi/testing/layer_norm_python.py @@ -25,13 +25,13 @@ def layer_norm_python(data, gamma, beta, axis, epsilon=1e-5): Parameters ---------- data : numpy.ndarray - N-D with shape (d_0, d_1, ..., d_n) + N-D with shape (d_0, d_1, ..., d_{N-1}) gamma: numpy.ndarray - R-D with shape (r_0, r_1, ..., r_r) where R == len(axis) and d_{axis_i} == r_i + K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k beta: numpy.ndarray - Optional, R-D with shape (r_0, r_1, ..., r_r) where R == len(axis) and d_{axis_i} == r_i + Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k axis : int or tuple of ints Axis over the normalization applied @@ -42,7 +42,7 @@ def layer_norm_python(data, gamma, beta, axis, epsilon=1e-5): Returns ------- result : np.ndarray - N-D with shape (d_0, d_1, ..., d_n) + N-D with shape (d_0, d_1, ..., d_{N-1}) """ mean = np.mean(data, axis, keepdims=True) var = np.var(data, axis, keepdims=True)
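For reference, the one-pass statistics produced by the tuple-sum reduction rely on the identity Var[x] = E[x^2] - (E[x])^2: the two reduction outputs are sum(x) and sum(x*x), from which layer_norm.h recovers mean = sum(x)/N and var = sum(x*x)/N - mean^2. A minimal NumPy sketch of that computation (shapes, dtype, and epsilon below are illustrative and not taken from the patch; gamma and beta are omitted):

import numpy as np

rng = np.random.default_rng(0)
data = rng.standard_normal((4, 16)).astype("float32")  # illustrative input
axis = (1,)
epsilon = 1e-5

# One-pass statistics from the two reduction outputs (temp_x, temp_x2).
n = np.prod([data.shape[i] for i in axis])              # reduce_extent
sum_x = data.sum(axis=axis, keepdims=True)              # temp_x
sum_x2 = (data * data).sum(axis=axis, keepdims=True)    # temp_x2
mean = sum_x / n
var = sum_x2 / n - mean * mean                          # E[x^2] - (E[x])^2
one_pass = (data - mean) / np.sqrt(var + epsilon)

# Two-pass reference, as in layer_norm_python.py (without gamma/beta).
ref = (data - data.mean(axis=axis, keepdims=True)) / np.sqrt(
    data.var(axis=axis, keepdims=True) + epsilon
)
np.testing.assert_allclose(one_pass, ref, rtol=1e-4, atol=1e-5)

The two-pass reference in layer_norm_python.py computes the same quantities with np.mean and np.var; the single-pass form is what lets LowerCrossThreadReduction emit one tvm_thread_allreduce over both accumulation buffers, as shown in lowered_layer_norm_tuple_sum above.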