diff --git a/include/tvm/topi/nn/layer_norm.h b/include/tvm/topi/nn/layer_norm.h index 93e5582ef184..371d8e1c8102 100644 --- a/include/tvm/topi/nn/layer_norm.h +++ b/include/tvm/topi/nn/layer_norm.h @@ -38,7 +38,7 @@ using namespace tvm::te; /*! * \brief Layer normalization. * \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}] - * \param gamma K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and + * \param gamma Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and * d_{axis_k} == r_k * \param beta Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where * d_{axis_k} == r_k @@ -101,7 +101,9 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& auto mean = temp_x(non_reduce_indices) / reduce_extent; auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean; auto layer_norm = (data(indices) - mean) * tvm::rsqrt(var + make_const(var->dtype, epsilon)); - layer_norm = topi::multiply(layer_norm, gamma(reduce_indices)); + if (gamma.defined()) { + layer_norm = topi::multiply(layer_norm, gamma(reduce_indices)); + } if (beta.defined()) { layer_norm = topi::add(layer_norm, beta(reduce_indices)); } diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index e956c82828c1..e552b71edeee 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -47,6 +47,9 @@ # log_softmax reg.register_strategy("nn.log_softmax", strategy.log_softmax_strategy) +# layer_norm +reg.register_strategy("nn.layer_norm", strategy.layer_norm_strategy) + @reg.register_legalize("nn.matmul") def legalize_matmul(attrs, inputs, types): diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index cc438092666a..c130dd3ba0a1 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -125,6 +125,17 @@ def log_softmax_strategy_cuda(attrs, inputs, out_type, target): return strategy +@layer_norm_strategy.register(["cuda", "gpu"]) +def layer_norm_strategy_cuda(attrs, inputs, out_type, target): + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_layer_norm_strategy(topi.nn.layer_norm), + wrap_topi_schedule(topi.cuda.schedule_injective), + name="layer_norm.cuda", + ) + return strategy + + @schedule_lrn.register(["cuda", "gpu"]) def schedule_lrn_cuda(attrs, outs, target): """schedule LRN for cuda""" diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 1cf55f7145cd..f37503d76100 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -197,6 +197,36 @@ def log_softmax_strategy(attrs, inputs, out_type, target): return strategy +def wrap_layer_norm_strategy(topi_compute): + """Wrap softmax topi compute""" + + def _compute_layer_norm(attrs, inputs, out_type): + axis = attrs.axis + epsilon = attrs.epsilon + return [ + topi_compute( + inputs[0], + inputs[1] if attrs.scale else None, + inputs[2] if attrs.center else None, + [axis], + epsilon, + ) + ] + + return _compute_layer_norm + + +@override_native_generic_func("layer_norm_strategy") +def layer_norm_strategy(attrs, inputs, out_type, target): + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_layer_norm_strategy(topi.nn.layer_norm), + wrap_topi_schedule(topi.generic.schedule_injective), + name="layer_norm.generic", + ) + return strategy + + # lrn @generic_func def schedule_lrn(attrs, outs, target): diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 9e2fe63b006a..2165e13bc406 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -1024,7 +1024,8 @@ RELAY_REGISTER_OP("nn.layer_norm") .set_attr("FInferCorrectLayout", NormalizationInferCorrectLayout) .set_support_level(1) - .add_type_rel("LayerNorm", LayerNormRel); + .add_type_rel("LayerNorm", LayerNormRel) + .set_attr("TOpPattern", kOutEWiseFusable); // group_norm TVM_REGISTER_NODE_TYPE(GroupNormAttrs); diff --git a/src/relay/transforms/simplify_inference.cc b/src/relay/transforms/simplify_inference.cc index e7eef41e41c4..ab587a902caf 100644 --- a/src/relay/transforms/simplify_inference.cc +++ b/src/relay/transforms/simplify_inference.cc @@ -115,29 +115,6 @@ Expr GroupNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, return out; } -Expr LayerNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) { - auto ttype = tdata.as(); - ICHECK(ttype); - const auto param = attrs.as(); - ICHECK(param); - - Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast(param->epsilon)); - Expr mean = Mean(data, {param->axis}, true, false); - Expr var = Variance(data, mean, {param->axis}, true, false); - Expr denom = Sqrt(Add(var, epsilon)); - Expr out = Divide(Subtract(data, mean), denom); - - size_t ndim = ttype->shape.size(); - int axis = (param->axis < 0) ? param->axis + ndim : param->axis; - if (param->scale) { - out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis})); - } - if (param->center) { - out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis})); - } - return out; -} - Expr InstanceNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) { auto ttype = tdata.as(); ICHECK(ttype); @@ -207,10 +184,6 @@ class InferenceSimplifier : public MixedModeMutator { Expr Rewrite_(const CallNode* n, const Expr& new_n) { if (n->op == batch_norm_op_) { ty_map_[new_n.as()->args[0]] = n->args[0]->checked_type(); - } else if (n->op == layer_norm_op_) { - const auto* call = new_n.as(); - return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2], - n->args[0]->checked_type()); } else if (n->op == group_norm_op_) { const auto* call = new_n.as(); return GroupNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2], diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc index 72f43a8d4929..d30b705fcec9 100644 --- a/src/tir/schedule/primitive/get_block_loop.cc +++ b/src/tir/schedule/primitive/get_block_loop.cc @@ -80,12 +80,34 @@ Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent Array GetProducers(const ScheduleState& self, const StmtSRef& block_sref) { StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); - return tir::GetProducers(block_sref, self->GetBlockScope(scope_root)); + Array edges = self->GetBlockScope(scope_root)->GetDepsByDst(block_sref); + Array results; + std::unordered_set result_set; + results.reserve(edges.size()); + for (const Dependency& edge : edges) { + if ((edge->kind == DepKind::kRAW || edge->kind == DepKind::kWAW) && + !result_set.count(edge->src)) { + results.push_back(edge->src); + result_set.emplace(edge->src); + } + } + return results; } Array GetConsumers(const ScheduleState& self, const StmtSRef& block_sref) { StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); - return tir::GetConsumers(block_sref, self->GetBlockScope(scope_root)); + Array edges = self->GetBlockScope(scope_root)->GetDepsBySrc(block_sref); + Array results; + std::unordered_set result_set; + results.reserve(edges.size()); + for (const Dependency& edge : edges) { + if ((edge->kind == DepKind::kRAW || edge->kind == DepKind::kWAW) && + !result_set.count(edge->dst)) { + results.push_back(edge->dst); + result_set.emplace(edge->dst); + } + } + return results; } /******** InstructionKind Registration ********/ diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index bd4e1b72c3cd..de5f526da147 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import itertools +from math import gamma import numpy as np import pytest import tvm @@ -172,6 +174,56 @@ def check_binary_op(opfunc, ref, dtype): check_binary_op(opfunc, ref, dtype) +@tvm.testing.uses_gpu +def test_layer_norm(): + from tvm.topi.testing import layer_norm_python + + # based on topi test + def verify_layer_norm(dshape, dtype, gamma, beta, axis, center, scale, rtol, atol): + x = relay.Var("x", relay.TensorType(dshape, dtype)) + func = relay.Function( + [x], + relay.nn.layer_norm( + x, + relay.const([gamma] * dshape[axis], dtype=dtype), + relay.const([beta] * dshape[axis], dtype=dtype), + axis=axis, + center=center, + scale=scale, + ), + ) + for target, dev in tvm.testing.enabled_targets(): + if "cuda" in target or "nvptx" in target: + # CUDA needs tuning to work + continue + data = np.random.uniform(size=dshape).astype(dtype) + ref_res = layer_norm_python(data, gamma if scale else 1, beta if center else 0, axis) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data) + np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=rtol, atol=atol) + + for dtype, gamma, beta, axis, center, scale in itertools.product( + ["float16", "float32"], [3.0], [1.0], [1, 2], [True, False], [True, False] + ): + if dtype == "float16": + # Float16 version is a lot less accurate + rtol = 0.1 + atol = 0.1 + else: + rtol = 1e-5 + atol = 1e-3 + verify_layer_norm( + (1, 10, 10), + dtype, + gamma, + beta, + axis=axis, + center=center, + scale=scale, + rtol=rtol, + atol=atol, + ) + + @tvm.testing.uses_gpu def test_expand_dims(): # based on topi test