6 changes: 4 additions & 2 deletions include/tvm/topi/nn/layer_norm.h
@@ -38,7 +38,7 @@ using namespace tvm::te;
/*!
* \brief Layer normalization.
* \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}]
- * \param gamma K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and
+ * \param gamma Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and
* d_{axis_k} == r_k
* \param beta Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where
* d_{axis_k} == r_k
@@ -101,7 +101,9 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor&
  auto mean = temp_x(non_reduce_indices) / reduce_extent;
  auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean;
  auto layer_norm = (data(indices) - mean) * tvm::rsqrt(var + make_const(var->dtype, epsilon));
-  layer_norm = topi::multiply(layer_norm, gamma(reduce_indices));
+  if (gamma.defined()) {
+    layer_norm = topi::multiply(layer_norm, gamma(reduce_indices));
+  }
  if (beta.defined()) {
    layer_norm = topi::add(layer_norm, beta(reduce_indices));
  }
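With gamma now guarded by gamma.defined(), the topi compute accepts an undefined scale tensor, matching the already-optional beta. A minimal sketch of the optional path from Python, assuming the topi wrapper forwards None through the FFI as an undefined te::Tensor:

    import tvm
    from tvm import te, topi

    # Both affine terms omitted: the compute should emit only the
    # normalization, with no multiply or add.
    data = te.placeholder((4, 16), name="data", dtype="float32")
    out = topi.nn.layer_norm(data, None, None, axis=[1], epsilon=1e-5)
    s = te.create_schedule(out.op)
    mod = tvm.build(s, [data, out], target="llvm")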
3 changes: 3 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -47,6 +47,9 @@
# log_softmax
reg.register_strategy("nn.log_softmax", strategy.log_softmax_strategy)

+# layer_norm
+reg.register_strategy("nn.layer_norm", strategy.layer_norm_strategy)
+

@reg.register_legalize("nn.matmul")
def legalize_matmul(attrs, inputs, types):
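With the layer_norm registration above in place, the op carries an FTVMStrategy attribute and can be lowered directly instead of relying on a decomposition pass. A quick sanity check (sketch):

    import tvm
    from tvm import relay

    op = relay.op.get("nn.layer_norm")
    # A non-null FTVMStrategy means the compiler can lower the op directly.
    assert op.get_attr("FTVMStrategy") is not None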
11 changes: 11 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -125,6 +125,17 @@ def log_softmax_strategy_cuda(attrs, inputs, out_type, target):
return strategy


+@layer_norm_strategy.register(["cuda", "gpu"])
+def layer_norm_strategy_cuda(attrs, inputs, out_type, target):
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_layer_norm_strategy(topi.nn.layer_norm),
+        wrap_topi_schedule(topi.cuda.schedule_injective),
+        name="layer_norm.cuda",
+    )
+    return strategy


@schedule_lrn.register(["cuda", "gpu"])
def schedule_lrn_cuda(attrs, outs, target):
"""schedule LRN for cuda"""
30 changes: 30 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -197,6 +197,36 @@ def log_softmax_strategy(attrs, inputs, out_type, target):
return strategy


+def wrap_layer_norm_strategy(topi_compute):
+    """Wrap layer_norm topi compute"""
+
+    def _compute_layer_norm(attrs, inputs, out_type):
+        axis = attrs.axis
+        epsilon = attrs.epsilon
+        return [
+            topi_compute(
+                inputs[0],
+                inputs[1] if attrs.scale else None,
+                inputs[2] if attrs.center else None,
+                [axis],
+                epsilon,
+            )
+        ]
+
+    return _compute_layer_norm
+
+
+@override_native_generic_func("layer_norm_strategy")
+def layer_norm_strategy(attrs, inputs, out_type, target):
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_layer_norm_strategy(topi.nn.layer_norm),
+        wrap_topi_schedule(topi.generic.schedule_injective),
+        name="layer_norm.generic",
+    )
+    return strategy


# lrn
@generic_func
def schedule_lrn(attrs, outs, target):
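The wrap_layer_norm_strategy wrapper above is where the scale/center attrs gate the optional inputs: when either is false, None is passed in place of the corresponding tensor and the topi compute skips that affine term (see the layer_norm.h change). A hedged end-to-end sketch:

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 10, 10), dtype="float32")
    g = relay.const(np.ones(10, dtype="float32"))
    b = relay.const(np.zeros(10, dtype="float32"))
    # scale=False: the strategy passes None for gamma, so no multiply is emitted.
    y = relay.nn.layer_norm(x, g, b, axis=1, center=True, scale=False)
    f = relay.Function([x], y)
    data = np.random.uniform(size=(1, 10, 10)).astype("float32")
    res = relay.create_executor("graph", target="llvm").evaluate(f)(data)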
3 changes: 2 additions & 1 deletion src/relay/op/nn/nn.cc
@@ -1024,7 +1024,8 @@ RELAY_REGISTER_OP("nn.layer_norm")
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
NormalizationInferCorrectLayout<LayerNormAttrs>)
.set_support_level(1)
-    .add_type_rel("LayerNorm", LayerNormRel);
+    .add_type_rel("LayerNorm", LayerNormRel)
+    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);

// group_norm
TVM_REGISTER_NODE_TYPE(GroupNormAttrs);
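Tagging nn.layer_norm as kOutEWiseFusable lets FuseOps attach trailing elementwise ops to its output instead of treating the op as opaque. One way to observe this (a sketch, assuming default fusion settings):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 10, 10), dtype="float32")
    g = relay.const(np.ones(10, dtype="float32"))
    b = relay.const(np.zeros(10, dtype="float32"))
    y = relay.nn.relu(relay.nn.layer_norm(x, g, b, axis=-1))
    mod = tvm.IRModule.from_expr(relay.Function([x], y))
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.FuseOps(fuse_opt_level=2)(mod)
    # With kOutEWiseFusable, relu should land in the same fused
    # primitive function as layer_norm.
    print(mod)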
27 changes: 0 additions & 27 deletions src/relay/transforms/simplify_inference.cc
@@ -115,29 +115,6 @@ Expr GroupNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta,
return out;
}

-Expr LayerNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) {
-  auto ttype = tdata.as<TensorTypeNode>();
-  ICHECK(ttype);
-  const auto param = attrs.as<LayerNormAttrs>();
-  ICHECK(param);
-
-  Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
-  Expr mean = Mean(data, {param->axis}, true, false);
-  Expr var = Variance(data, mean, {param->axis}, true, false);
-  Expr denom = Sqrt(Add(var, epsilon));
-  Expr out = Divide(Subtract(data, mean), denom);
-
-  size_t ndim = ttype->shape.size();
-  int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
-  if (param->scale) {
-    out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis}));
-  }
-  if (param->center) {
-    out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis}));
-  }
-  return out;
-}

Expr InstanceNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) {
auto ttype = tdata.as<TensorTypeNode>();
ICHECK(ttype);
@@ -207,10 +184,6 @@ class InferenceSimplifier : public MixedModeMutator {
Expr Rewrite_(const CallNode* n, const Expr& new_n) {
if (n->op == batch_norm_op_) {
ty_map_[new_n.as<CallNode>()->args[0]] = n->args[0]->checked_type();
-    } else if (n->op == layer_norm_op_) {
-      const auto* call = new_n.as<CallNode>();
-      return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
-                                    n->args[0]->checked_type());
} else if (n->op == group_norm_op_) {
const auto* call = new_n.as<CallNode>();
return GroupNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
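With LayerNormToInferUnpack deleted, SimplifyInference no longer decomposes nn.layer_norm into mean/variance/sqrt; the op survives the pass and is lowered through the new strategy instead. Sketch:

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 10, 10), dtype="float32")
    g = relay.const(np.ones(10, dtype="float32"))
    b = relay.const(np.zeros(10, dtype="float32"))
    mod = tvm.IRModule.from_expr(
        relay.Function([x], relay.nn.layer_norm(x, g, b, axis=-1))
    )
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.SimplifyInference()(mod)
    # Previously this printed mean/variance/divide; the nn.layer_norm
    # call should now survive the pass intact.
    print(mod)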
26 changes: 24 additions & 2 deletions src/tir/schedule/primitive/get_block_loop.cc
@@ -80,12 +80,34 @@ Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent

Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sref) {
StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
-  return tir::GetProducers(block_sref, self->GetBlockScope(scope_root));
+  Array<Dependency> edges = self->GetBlockScope(scope_root)->GetDepsByDst(block_sref);
+  Array<StmtSRef> results;
+  std::unordered_set<StmtSRef, ObjectPtrHash, ObjectPtrEqual> result_set;
+  results.reserve(edges.size());
+  for (const Dependency& edge : edges) {
+    if ((edge->kind == DepKind::kRAW || edge->kind == DepKind::kWAW) &&
+        !result_set.count(edge->src)) {
+      results.push_back(edge->src);
+      result_set.emplace(edge->src);
+    }
+  }
+  return results;
}

Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sref) {
StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
-  return tir::GetConsumers(block_sref, self->GetBlockScope(scope_root));
+  Array<Dependency> edges = self->GetBlockScope(scope_root)->GetDepsBySrc(block_sref);
+  Array<StmtSRef> results;
+  std::unordered_set<StmtSRef, ObjectPtrHash, ObjectPtrEqual> result_set;
+  results.reserve(edges.size());
+  for (const Dependency& edge : edges) {
+    if ((edge->kind == DepKind::kRAW || edge->kind == DepKind::kWAW) &&
+        !result_set.count(edge->dst)) {
+      results.push_back(edge->dst);
+      result_set.emplace(edge->dst);
+    }
+  }
+  return results;
}

/******** InstructionKind Registration ********/
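These helpers back Schedule.get_producers and Schedule.get_consumers on the Python side; the dependency walk is now inlined, keeping only RAW and WAW edges and de-duplicating results. A small sketch of that surface, assuming current TVMScript syntax:

    import tvm
    from tvm import tir
    from tvm.script import tir as T

    @T.prim_func
    def func(a: T.handle, c: T.handle):
        A = T.match_buffer(a, (128,), "float32")
        C = T.match_buffer(c, (128,), "float32")
        B = T.alloc_buffer((128,), "float32")
        for i in T.serial(0, 128):
            with T.block("B"):
                vi = T.axis.spatial(128, i)
                B[vi] = A[vi] * 2.0
        for i in T.serial(0, 128):
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                C[vi] = B[vi] + 1.0

    sch = tir.Schedule(func)
    # Producer/consumer queries follow the RAW/WAW edges collected above.
    print(sch.get_producers(sch.get_block("C")))  # block "B"
    print(sch.get_consumers(sch.get_block("B")))  # block "C"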
52 changes: 52 additions & 0 deletions tests/python/relay/test_op_level1.py
@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import itertools
import numpy as np
import pytest
import tvm
@@ -172,6 +174,56 @@ def check_binary_op(opfunc, ref, dtype):
check_binary_op(opfunc, ref, dtype)


+@tvm.testing.uses_gpu
+def test_layer_norm():
+    from tvm.topi.testing import layer_norm_python
+
+    # based on topi test
+    def verify_layer_norm(dshape, dtype, gamma, beta, axis, center, scale, rtol, atol):
+        x = relay.Var("x", relay.TensorType(dshape, dtype))
+        func = relay.Function(
+            [x],
+            relay.nn.layer_norm(
+                x,
+                relay.const([gamma] * dshape[axis], dtype=dtype),
+                relay.const([beta] * dshape[axis], dtype=dtype),
+                axis=axis,
+                center=center,
+                scale=scale,
+            ),
+        )
+        for target, dev in tvm.testing.enabled_targets():
+            if "cuda" in target or "nvptx" in target:
+                # CUDA needs tuning to work
+                continue
+            data = np.random.uniform(size=dshape).astype(dtype)
+            ref_res = layer_norm_python(data, gamma if scale else 1, beta if center else 0, axis)
+            op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data)
+            np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=rtol, atol=atol)
+
+    for dtype, gamma, beta, axis, center, scale in itertools.product(
+        ["float16", "float32"], [3.0], [1.0], [1, 2], [True, False], [True, False]
+    ):
+        if dtype == "float16":
+            # Float16 version is a lot less accurate
+            rtol = 0.1
+            atol = 0.1
+        else:
+            rtol = 1e-5
+            atol = 1e-3
+        verify_layer_norm(
+            (1, 10, 10),
+            dtype,
+            gamma,
+            beta,
+            axis=axis,
+            center=center,
+            scale=scale,
+            rtol=rtol,
+            atol=atol,
+        )


@tvm.testing.uses_gpu
def test_expand_dims():
# based on topi test
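For reference, the math the new test checks against is a normalization over the given axis followed by the optional affine terms; a plausible numpy equivalent of layer_norm_python (illustrative, not the exact helper):

    import numpy as np

    def layer_norm_ref(x, gamma, beta, axis, epsilon=1e-5):
        # Normalize over `axis`, then apply scale and shift; passing
        # gamma=1 / beta=0 reproduces the scale=False / center=False cases.
        mean = x.mean(axis=axis, keepdims=True)
        var = x.var(axis=axis, keepdims=True)
        return (x - mean) / np.sqrt(var + epsilon) * gamma + beta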