diff --git a/nnvm/include/nnvm/top/nn.h b/nnvm/include/nnvm/top/nn.h
index bbdb3b9c4f12..0e0b03a72ebb 100644
--- a/nnvm/include/nnvm/top/nn.h
+++ b/nnvm/include/nnvm/top/nn.h
@@ -368,6 +368,41 @@ struct NMSParam : public dmlc::Parameter<NMSParam> {
   }
 };
 
+struct LRNParam : public dmlc::Parameter<LRNParam> {
+  int size;
+  int axis;
+  float alpha;
+  float beta;
+  float bias;
+
+  DMLC_DECLARE_PARAMETER(LRNParam) {
+    DMLC_DECLARE_FIELD(size)
+      .describe("The size of the local region to be considered for normalization.");
+    DMLC_DECLARE_FIELD(axis)
+      .describe("input data layout channel axis");
+    DMLC_DECLARE_FIELD(alpha)
+      .describe("The scaling parameter.");
+    DMLC_DECLARE_FIELD(beta)
+      .describe("The exponent parameter.");
+    DMLC_DECLARE_FIELD(bias)
+      .describe("The offset parameter.");
+  }
+  // constants
+  static const constexpr int kData = 0;
+};
+
+struct L2NormalizeParam : public dmlc::Parameter<L2NormalizeParam> {
+  float eps;
+  Tuple<int> axis;
+
+  DMLC_DECLARE_PARAMETER(L2NormalizeParam) {
+    DMLC_DECLARE_FIELD(eps)
+      .describe("float type epsilon value.");
+    DMLC_DECLARE_FIELD(axis)
+      .describe("axis over the normalization applied");
+  }
+};
+
 }  // namespace top
 }  // namespace nnvm
diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py
index b7e0d0952888..5bfabdac2c8d 100644
--- a/nnvm/python/nnvm/top/nn.py
+++ b/nnvm/python/nnvm/top/nn.py
@@ -243,3 +243,36 @@ def schedule_upsampling(_, outs, target):
         return topi.generic.schedule_injective(outs)
 
 reg.register_pattern("upsampling", OpPattern.INJECTIVE)
+
+@reg.register_compute("lrn")
+def compute_lrn(attrs, inputs, _):
+    """Compute definition of lrn"""
+    size = attrs.get_int("size")
+    axis = attrs.get_int("axis")
+    alpha = attrs.get_float("alpha")
+    beta = attrs.get_float("beta")
+    bias = attrs.get_float("bias")
+    return topi.nn.lrn(inputs[0], size, axis, alpha, beta, bias)
+
+@reg.register_schedule("lrn")
+def schedule_lrn(attrs, outs, target):
+    """Schedule definition of lrn"""
+    with tvm.target.create(target):
+        return topi.generic.schedule_lrn(outs)
+
+reg.register_pattern("lrn", OpPattern.OPAQUE)
+
+@reg.register_compute("l2_normalize")
+def compute_l2_normalize(attrs, inputs, _):
+    """Compute definition of l2 normalize"""
+    eps = attrs.get_float("eps")
+    axis = attrs.get_int_tuple("axis")
+    return topi.nn.l2_normalize(inputs[0], eps, axis)
+
+@reg.register_schedule("l2_normalize")
+def schedule_l2_normalize(attrs, outs, target):
+    """Schedule definition of l2 normalize"""
+    with tvm.target.create(target):
+        return topi.generic.schedule_l2_normalize(outs)
+
+reg.register_pattern("l2_normalize", OpPattern.OUT_ELEMWISE_FUSABLE)
diff --git a/nnvm/src/top/nn/nn.cc b/nnvm/src/top/nn/nn.cc
index cedfb210855e..ab47ae521224 100644
--- a/nnvm/src/top/nn/nn.cc
+++ b/nnvm/src/top/nn/nn.cc
@@ -712,5 +712,52 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w]
 })
 .set_support_level(1);
 
+DMLC_REGISTER_PARAMETER(LRNParam);
+
+inline bool LRNInferShape(const nnvm::NodeAttrs& attrs,
+                          std::vector<TShape>* in_shape,
+                          std::vector<TShape>* out_shape) {
+  TShape dshape = (*in_shape)[0];
+  TShape oshape = dshape;
+
+  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
+  return true;
+}
+
+NNVM_REGISTER_OP(lrn)
+.describe(R"code(LRN layer)code" NNVM_ADD_FILELINE)
+.add_argument("data", "4D Tensor", "Input data.")
+.set_attr_parser(ParamParser<LRNParam>)
+.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<LRNParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<FInferShape>("FInferShape", LRNInferShape)
+.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_support_level(1);
+
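Not part of the patch: a minimal end-to-end sketch of how the two symbols registered above are exercised, mirroring the NNVM tests added later in this diff; the shapes, attribute values, and llvm target are illustrative only.

    # Illustrative usage of the new NNVM symbols (assumes an LLVM-enabled TVM/NNVM build).
    import numpy as np
    import tvm
    import nnvm.compiler
    import nnvm.symbol as sym
    from tvm.contrib import graph_runtime

    ishape = (1, 3, 20, 20)
    x = sym.Variable("x")
    y = sym.lrn(x, size=3, axis=1, bias=1.0, alpha=1.0, beta=0.5)
    z = sym.l2_normalize(y, eps=0.001, axis=(1,))

    graph, lib, _ = nnvm.compiler.build(z, "llvm", {"x": ishape})
    m = graph_runtime.create(graph, lib, tvm.cpu(0))
    m.run(x=np.random.uniform(size=ishape).astype("float32"))
    out = m.get_output(0, tvm.nd.empty(ishape)).asnumpy()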
+DMLC_REGISTER_PARAMETER(L2NormalizeParam); + +inline bool L2NormalizeInferShape(const nnvm::NodeAttrs& attrs, + std::vector* in_shape, + std::vector* out_shape) { + TShape dshape = (*in_shape)[0]; + TShape oshape = dshape; + + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape); + return true; +} + +NNVM_REGISTER_OP(l2_normalize) +.describe(R"code(L2NORMALIZE layer)code" NNVM_ADD_FILELINE) +.add_argument("data", "4D Tensor", "Input data.") +.set_attr_parser(ParamParser) +.set_attr("FGetAttrDict", ParamGetAttrDict) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FInferShape", L2NormalizeInferShape) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>) +.set_support_level(1); + } // namespace top } // namespace nnvm diff --git a/nnvm/tests/python/compiler/test_top_level1.py b/nnvm/tests/python/compiler/test_top_level1.py index 3058d6ccfc7b..37798d37f400 100644 --- a/nnvm/tests/python/compiler/test_top_level1.py +++ b/nnvm/tests/python/compiler/test_top_level1.py @@ -6,7 +6,6 @@ import nnvm.compiler from nnvm.testing.config import ctx_list - def helper(symbol, inputs, dtype, np_forward, np_backward=None, need_input=True, need_head_grads=True): ishapes = {} @@ -365,6 +364,65 @@ def forward(x): inputs = [('x', (1, 3, 28, 28), x)] helper(y, inputs, dtype, forward) +def verify_lrn(ishape, size, axis, bias, alpha, beta): + x = sym.Variable("x") + y = sym.lrn(x, size=size, axis=axis, bias=bias, alpha=alpha, beta=beta) + dtype = "float32" + x_np = np.random.uniform(size=ishape).astype(dtype) + + for target, ctx in ctx_list(): + graph, lib, _ = nnvm.compiler.build(y, target, {"x": ishape}) + m = graph_runtime.create(graph, lib, ctx) + m.run(x=x_np) + out = m.get_output(0, tvm.nd.empty(ishape)) + out_np = topi.testing.lrn_python(x_np, size, axis, bias, alpha, beta) + np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5) + + #Checking LRN op followed by elementwise op relu + z = sym.relu(y) + x_np = np.random.uniform(low=-10.0, high=10.0, size=ishape).astype(dtype) + for target, ctx in ctx_list(): + graph, lib, _ = nnvm.compiler.build(z, target, {"x": ishape}) + m = graph_runtime.create(graph, lib, ctx) + m.run(x=x_np) + out = m.get_output(0, tvm.nd.empty(ishape)) + out_np = topi.testing.lrn_python(x_np, size, axis, bias, alpha, beta) + out_np = (out_np > 0) * out_np + np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5) + +def verify_l2_normalize(ishape, eps, axis): + x = sym.Variable("x") + y = sym.l2_normalize(x, eps=eps, axis=axis) + dtype = "float32" + x_np = np.random.uniform(size=ishape).astype(dtype) + + for target, ctx in ctx_list(): + graph, lib, _ = nnvm.compiler.build(y, target, {"x": ishape}) + m = graph_runtime.create(graph, lib, ctx) + m.run(x=x_np) + out = m.get_output(0, tvm.nd.empty(ishape)) + out_np = topi.testing.l2_normalize_python(x_np, eps, axis) + np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5) + + #Checking L2 normalization op followed by elementwise op relu + z = sym.relu(y) + x_np = np.random.uniform(low=-10.0, high=10.0, size=ishape).astype(dtype) + for target, ctx in ctx_list(): + graph, lib, _ = nnvm.compiler.build(z, target, {"x": ishape}) + m = graph_runtime.create(graph, lib, ctx) + m.run(x=x_np) + out = m.get_output(0, tvm.nd.empty(ishape)) + out_np = topi.testing.l2_normalize_python(x_np, eps, axis) + out_np = (out_np > 0) * out_np + np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5) + +def test_lrn(): + verify_lrn((1, 3, 20, 
20), 3, 1, 1.0, 1.0, 0.5)
+    verify_lrn((1, 3, 20, 20), 3, 1, 2.0, 1.0, 0.75)
+
+def test_l2_normalize():
+    verify_l2_normalize((1, 3, 20, 20), 0.001, (1,))
+    verify_l2_normalize((1, 3, 20, 20), 0.001, (1, 2))
 
 if __name__ == "__main__":
     test_split()
@@ -384,3 +442,5 @@ def forward(x):
     test_softmax()
     test_squeeze()
     test_pad()
+    test_lrn()
+    test_l2_normalize()
diff --git a/topi/include/topi/cuda/normalization.h b/topi/include/topi/cuda/normalization.h
new file mode 100644
index 000000000000..91578c46d266
--- /dev/null
+++ b/topi/include/topi/cuda/normalization.h
@@ -0,0 +1,106 @@
+/*!
+*  Copyright (c) 2018 by Contributors
+* \file cuda/normalization.h
+* \brief CUDA schedule for LRN and l2 normalization operations
+*/
+#ifndef TOPI_CUDA_NORMALIZATION_H_
+#define TOPI_CUDA_NORMALIZATION_H_
+
+#include "tvm/tvm.h"
+#include "tvm/build_module.h"
+#include "topi/tags.h"
+
+namespace topi {
+using namespace tvm;
+namespace cuda {
+/*!
+* \brief Create a CUDA schedule for LRN
+*
+* \param target The target to generate a schedule for.
+* \param outs The output tensors.
+*
+* \return A schedule for the given ops.
+*/
+inline Schedule schedule_lrn(const Target &target, const Array<Tensor>& outs) {
+  Array<Operation> out_ops;
+  for (auto t : outs) {
+    out_ops.push_back(t->op);
+  }
+  Schedule s = create_schedule(out_ops);
+  int num_thread = 64;
+  IterVar block_x = tvm::thread_axis(Range(), "blockIdx.x");
+  IterVar thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
+  Tensor lrn = outs[0];
+  Tensor sqr_sum_up = lrn->op->InputTensors()[1];
+  Tensor sqr_sum = sqr_sum_up->op->InputTensors()[0];
+  Tensor set_pad = sqr_sum->op->InputTensors()[0];
+  s[set_pad].bind(set_pad->op.as<ComputeOpNode>()->axis[0], block_x);
+  IterVar rxk = sqr_sum->op.as<ComputeOpNode>()->reduce_axis[0];
+  IterVar xko, xki;
+  s[sqr_sum].split(rxk, num_thread, &xko, &xki);
+  Tensor srf = s.rfactor(sqr_sum, xki)[0];
+  s[sqr_sum].bind(s[sqr_sum]->op.as<ComputeOpNode>()->axis[0], block_x);
+  s[sqr_sum].bind(s[sqr_sum]->op.as<ComputeOpNode>()->reduce_axis[0], thread_x);
+  s[srf].compute_at(s[sqr_sum], s[sqr_sum]->op.as<ComputeOpNode>()->reduce_axis[0]);
+  s[sqr_sum_up].bind(sqr_sum_up->op.as<ComputeOpNode>()->axis[0], block_x);
+  IterVar xto, xti;
+  s[lrn].split_by_nparts(lrn->op.as<ComputeOpNode>()->axis[1], num_thread, &xto, &xti);
+  s[lrn].bind(lrn->op.as<ComputeOpNode>()->axis[0], block_x);
+  s[lrn].bind(xto, thread_x);
+
+  return s;
+}
+
+/*!
+* \brief Create a CUDA schedule for L2 normalization
+*
+* \param target The target to generate a schedule for.
+* \param outs The output tensors.
+*
+* \return A schedule for the given ops.
+*/
+inline Schedule schedule_l2_normalize(const Target &target, const Array<Tensor>& outs) {
+  Array<Operation> out_ops;
+  for (auto t : outs) {
+    out_ops.push_back(t->op);
+  }
+  Schedule s = create_schedule(out_ops);
+
+  std::function<void(Operation)> traverse;
+  traverse = [&](const Operation& op) {
+    // Inline all one-to-one-mapping operators except the last stage (output)
+    if (is_injective(op->tag) || op->tag == "l2_normalize") {
+      if (!detail::contains(s->outputs, op)) {
+        s[op].compute_inline();
+      }
+      for (auto tensor : op->InputTensors()) {
+        if (tensor->op->InputTensors().size() > 0) {
+          traverse(tensor->op);
+        }
+      }
+    } else if (op->tag == "comm_reduce") {
+      ScheduleReduce(target, op, s, false);
+      for (auto tensor : op->InputTensors()) {
+        traverse(tensor->op);
+      }
+    } else {
+      LOG(ERROR) << "Unsupported operator " << op->tag;
+    }
+  };
+
+  traverse(outs[0]->op);
+  int num_thread = 64;
+  Tensor l2_normalize = outs[0];
+  IterVar block_x = tvm::thread_axis(Range(), "blockIdx.x");
+  IterVar thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
+  IterVar xto, xti;
+  s[l2_normalize].split_by_nparts(l2_normalize->op.as<ComputeOpNode>()->axis[1],
+                                  num_thread, &xto, &xti);
+  s[l2_normalize].bind(l2_normalize->op.as<ComputeOpNode>()->axis[0], block_x);
+  s[l2_normalize].bind(xto, thread_x);
+  return s;
+}
+}  // namespace cuda
+}  // namespace topi
+#endif  // TOPI_CUDA_NORMALIZATION_H_
+
diff --git a/topi/include/topi/nn/l2_normalize.h b/topi/include/topi/nn/l2_normalize.h
new file mode 100644
index 000000000000..079c6d467561
--- /dev/null
+++ b/topi/include/topi/nn/l2_normalize.h
@@ -0,0 +1,46 @@
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \brief l2 normalization op constructions
+ * \file nn/l2_normalize.h
+ */
+#ifndef TOPI_NN_L2_NORMALIZE_H_
+#define TOPI_NN_L2_NORMALIZE_H_
+
+#include <string>
+#include <algorithm>
+#include "topi/tags.h"
+#include "tvm/tvm.h"
+namespace topi {
+namespace nn {
+using namespace tvm;
+
+/*!
+* \brief L2 normalization inference operator
+*
+* \param data The input tensor. 4-D with shape [batch, channel, height, width]
+* \param eps Epsilon to prevent div by 0
+* \param axis Axes over the normalization applied
+* \param name The name of the operation
+* \param tag The tag to mark the operation
+*
+* \return A Tensor whose op member is the l2 normalization operation
+*/
+inline Tensor l2_normalize(const Tensor& data,
+                           float eps,
+                           const Array<Expr>& axis,
+                           std::string name = "tensor",
+                           std::string tag = "l2_normalize") {
+  CHECK_EQ(data->shape.size(), 4) << "L2 normalization requires 4-D input";
+  auto input_shape = data->shape;
+  Tensor dot_value = pow(data, static_cast<float>(2.0));
+  Tensor sum_value = topi::sum(dot_value, axis, true);
+  Tensor expand_sum = topi::broadcast_to(sum_value, input_shape);
+  return topi::broadcast_div(data,
+                             topi::sqrt(tvm::compute(expand_sum->shape,
+                                        [&](const Array<Var>& i){
+                                          return (max(expand_sum(i), eps));
+                                        }, name = name, tag = tag)));
+}
+}  // namespace nn
+}  // namespace topi
+#endif  // TOPI_NN_L2_NORMALIZE_H_
diff --git a/topi/include/topi/nn/local_response_norm.h b/topi/include/topi/nn/local_response_norm.h
new file mode 100644
index 000000000000..c956a9c253dc
--- /dev/null
+++ b/topi/include/topi/nn/local_response_norm.h
@@ -0,0 +1,76 @@
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \brief local response normalization op constructions
+ * \file nn/local_response_norm.h
+ */
+#ifndef TOPI_NN_LOCAL_RESPONSE_NORM_H_
+#define TOPI_NN_LOCAL_RESPONSE_NORM_H_
+
+#include <string>
+
+#include "topi/tags.h"
+#include "tvm/tvm.h"
+
+namespace topi {
+namespace nn {
+using namespace tvm;
+
+/*!
+* \brief Local response normalization inference operator
+*
+* \param data The input tensor. 4-D shape NCHW or NHWC
+* \param size Integer to define normalisation window size
+* \param axis Input data layout channel axis
+* \param alpha Float scaling factor
+* \param beta Exponent value
+* \param bias Offset to avoid dividing by zero
+* \param name The name of the operation
+* \param tag The tag to mark the operation
+*
+* \return A Tensor whose op member is the Local response normalization operation
+*/
+inline Tensor lrn(const Tensor& data,
+                  int size,
+                  int axis = 1,
+                  float alpha = 0.0001,
+                  float beta = 0.75,
+                  float bias = 2,
+                  std::string name = "tensor",
+                  std::string tag = kBroadcast) {
+  CHECK_EQ(data->shape.size(), 4) << "LRN requires 4-D input";
+  CHECK_EQ(size % 2, 1) << "size should be odd number";
+  CHECK(axis == 1 || axis == 3) << "axis should be 1 or 3 for NCHW and NHWC";
+  auto input_shape = data->shape;
+  Array<Expr> pad_before{ 0, 0, 0, 0};
+  Array<Expr> pad_after{ 0, 0, 0, 0};
+  pad_before.Set(axis, static_cast<Expr>(size/2));
+  pad_after.Set(axis, static_cast<Expr>(size/2));
+  auto pad_data = pad(data, pad_before, pad_after, 0, "pad_data");
+  auto rxs = tvm::reduce_axis(Range(0, size), "rxs");
+  Tensor sqr_sum;
+  if (axis == 1) {
+    sqr_sum = tvm::compute(input_shape,
+                           [&](Var i, Var l, Var j, Var k) {
+                           return tvm::sum(pad_data(i, l + rxs, j, k) *
+                                           pad_data(i, l + rxs, j, k),
+                                           {rxs});
+                           });
+  } else if (axis == 3) {
+    sqr_sum = tvm::compute(input_shape,
+                           [&](Var i, Var l, Var j, Var k) {
+                           return tvm::sum(pad_data(i, l, j, k + rxs) *
+                                           pad_data(i, l, j, k + rxs),
+                                           {rxs});
+                           });
+  }
+  auto sqrt_sum_up = tvm::compute(input_shape,
+                                  [&](Var i, Var j, Var k, Var l) {
+                                  return tvm::pow(bias +
+                                                  (alpha * sqr_sum(i, j, k, l) / size),
+                                                  beta);
+                                  });
+  return topi::broadcast_div(data, sqrt_sum_up);
+}
+}  // namespace nn
+}  // namespace topi
+#endif  // TOPI_NN_LOCAL_RESPONSE_NORM_H_
diff --git a/topi/include/topi/rocm/normalization.h b/topi/include/topi/rocm/normalization.h
new file mode 100644
index 000000000000..b12e64aba963
--- /dev/null
+++ b/topi/include/topi/rocm/normalization.h
@@ -0,0 +1,41 @@
+/*!
+*  Copyright (c) 2018 by Contributors
+* \file rocm/normalization.h
+* \brief rocm schedule for LRN and l2 normalization operations
+*/
+#ifndef TOPI_ROCM_NORMALIZATION_H_
+#define TOPI_ROCM_NORMALIZATION_H_
+
+#include "tvm/tvm.h"
+#include "tvm/build_module.h"
+#include "topi/tags.h"
+
+namespace topi {
+using namespace tvm;
+namespace rocm {
+/*!
+* \brief Create a rocm schedule for LRN
+*
+* \param target The target to generate a schedule for.
+* \param outs The output tensors.
+*
+* \return A schedule for the given ops.
+*/
+inline Schedule schedule_lrn(const Target &target, const Array<Tensor>& outs) {
+  return topi::cuda::schedule_lrn(target, outs);
+}
+
+/*!
+* \brief Create a rocm schedule for L2 Normalization
+*
+* \param target The target to generate a schedule for.
+* \param outs The output tensors.
+*
+* \return A schedule for the given ops.
+*/ +inline Schedule schedule_l2_normalize(const Target &target, const Array& outs) { + return topi::cuda::schedule_l2_normalize(target, outs); +} +} // namespace rocm +} // namespace topi +#endif // TOPI_ROCM_NORMALIZATION_H_ diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py index 3b0e38c4d3f4..dbf00ebeb52b 100644 --- a/topi/python/topi/cuda/__init__.py +++ b/topi/python/topi/cuda/__init__.py @@ -17,4 +17,4 @@ from .extern import schedule_extern from .vision import schedule_region from .vision import schedule_reorg -from .nn import schedule_lrn, schedule_l2norm +from .nn import schedule_lrn, schedule_l2_normalize diff --git a/topi/python/topi/cuda/nn.py b/topi/python/topi/cuda/nn.py index e8757970505b..b503b2dad50f 100644 --- a/topi/python/topi/cuda/nn.py +++ b/topi/python/topi/cuda/nn.py @@ -4,8 +4,7 @@ import tvm from .. import generic -from .. import tag -from .reduction import _schedule_reduce +from .. import cpp @generic.schedule_lrn.register(["cuda"]) def schedule_lrn(outs): @@ -22,37 +21,18 @@ def schedule_lrn(outs): sch: Schedule The computation schedule for the op. """ - outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs - s = tvm.create_schedule([x.op for x in outs]) - num_thread = 64 - block_x = tvm.thread_axis("blockIdx.x") - thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x") + target = tvm.target.current_target(allow_none=False) + cpp_target = cpp.TEST_create_target(target.target_name) + return cpp.cuda.schedule_lrn(cpp_target, outs) - lrn = outs[0] - sqr_sum_up = lrn.op.input_tensors[1] - sqr_sum = sqr_sum_up.op.input_tensors[0] - set_pad = sqr_sum.op.input_tensors[0] - s[set_pad].bind(set_pad.op.axis[0], block_x) - rxk = sqr_sum.op.reduce_axis[0] - _, xki = s[sqr_sum].split(rxk, factor=num_thread) - srf = s.rfactor(sqr_sum, xki) - s[sqr_sum].bind(s[sqr_sum].op.axis[0], block_x) - s[sqr_sum].bind(s[sqr_sum].op.reduce_axis[0], thread_x) - s[srf].compute_at(s[sqr_sum], s[sqr_sum].op.reduce_axis[0]) - s[sqr_sum_up].bind(sqr_sum_up.op.axis[0], block_x) - xto, _ = s[lrn].split(lrn.op.axis[1], nparts=num_thread) - s[lrn].bind(lrn.op.axis[0], block_x) - s[lrn].bind(xto, thread_x) - return s - -@generic.schedule_l2norm.register(["cuda"]) -def schedule_l2norm(outs): - """Schedule for L2norm +@generic.schedule_l2_normalize.register(["cuda"]) +def schedule_l2_normalize(outs): + """Schedule for L2 normalize Parameters ---------- outs: Array of Tensor - The computation graph description of L2norm + The computation graph description of L2 normalize in the format of an array of tensors. Returns @@ -60,32 +40,6 @@ def schedule_l2norm(outs): sch: Schedule The computation schedule for the op. 
""" - outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs - s = tvm.create_schedule([x.op for x in outs]) - - def traverse(OP): - '''inline all one-to-one-mapping operators - except the last stage (output)''' - if tag.is_injective(OP.tag) or OP.tag == 'l2norm': - if OP not in s.outputs: - s[OP].compute_inline() - for tensor in OP.input_tensors: - if tensor.op.input_tensors: - traverse(tensor.op) - elif OP.tag == 'comm_reduce': - _schedule_reduce(OP, s, is_idx_reduce=False) - for tensor in OP.input_tensors: - traverse(tensor.op) - else: - raise RuntimeError("Unsupported operator tag: %s" % OP.tag) - traverse(outs[0].op) - - num_thread = 64 - l2norm = outs[0] - block_x = tvm.thread_axis("blockIdx.x") - thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x") - xto, _ = s[l2norm].split(l2norm.op.axis[1], nparts=num_thread) - s[l2norm].bind(l2norm.op.axis[0], block_x) - s[l2norm].bind(xto, thread_x) - - return s + target = tvm.target.current_target(allow_none=False) + cpp_target = cpp.TEST_create_target(target.target_name) + return cpp.cuda.schedule_l2_normalize(cpp_target, outs) diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py index 5a16d12206a3..8f2f8612c7fa 100644 --- a/topi/python/topi/generic/nn.py +++ b/topi/python/topi/generic/nn.py @@ -2,7 +2,7 @@ """Generic nn operators""" from __future__ import absolute_import as _abs import tvm - +from .. import cpp def _default_schedule(outs, auto_inline): """Default schedule for llvm.""" @@ -273,17 +273,18 @@ def schedule_lrn(outs): sch: Schedule The computation schedule for the op. """ - return _default_schedule(outs, False) - + target = tvm.target.current_target(allow_none=False) + cpp_target = cpp.TEST_create_target(target.target_name) + return cpp.generic.default_schedule(cpp_target, outs, False) @tvm.target.generic_func -def schedule_l2norm(outs): - """Schedule for l2norm +def schedule_l2_normalize(outs): + """Schedule for l2 normalize Parameters ---------- outs: Array of Tensor - The computation graph description of l2norm + The computation graph description of l2 normalize in the format of an array of tensors. Returns @@ -291,4 +292,6 @@ def schedule_l2norm(outs): sch: Schedule The computation schedule for the op. 
""" - return _default_schedule(outs, False) + target = tvm.target.current_target(allow_none=False) + cpp_target = cpp.TEST_create_target(target.target_name) + return cpp.generic.default_schedule(cpp_target, outs, False) diff --git a/topi/python/topi/nn/__init__.py b/topi/python/topi/nn/__init__.py index 056d1a76339a..7b6ee4a86836 100644 --- a/topi/python/topi/nn/__init__.py +++ b/topi/python/topi/nn/__init__.py @@ -16,4 +16,4 @@ from .bnn import * from .upsampling import * from .local_response_norm import * -from .l2_norm import * +from .l2_normalize import * diff --git a/topi/python/topi/nn/l2_norm.py b/topi/python/topi/nn/l2_norm.py deleted file mode 100644 index 6b5381a85599..000000000000 --- a/topi/python/topi/nn/l2_norm.py +++ /dev/null @@ -1,35 +0,0 @@ -# pylint: disable=invalid-name -"""TVM operator for l2norm""" -from __future__ import absolute_import -import tvm -import topi - -@tvm.target.generic_func -def l2norm_instance(data, eps, axis=None): - """Perform L2norm on the input data - - For axis=None, y(i, j) = x(i, j) / sqrt(max(sum(x^2), eps)) - - Parameters - ---------- - data : tvm.Tensor - 4-D with NCHW or NHWC layout - - eps : float - epsilon value - - axis : list of int - axis over the normalization applied - - Returns - ------- - output : tvm.Tensor - 4-D output with same shape - """ - assert len(data.shape) == 4, "only support 4-dim lrn" - dot_value = topi.cpp.pow(data, 2.0) - sum_value = topi.sum(dot_value, axis=axis, keepdims=True) - expand_sum = topi.broadcast_to(sum_value, data.shape) - return topi.broadcast_div(data, topi.sqrt(\ - tvm.compute(expand_sum.shape, lambda i, j, k, l:\ - tvm.max(expand_sum[i, j, k, l], eps), tag='l2norm'))) diff --git a/topi/python/topi/nn/l2_normalize.py b/topi/python/topi/nn/l2_normalize.py new file mode 100644 index 000000000000..951084379eec --- /dev/null +++ b/topi/python/topi/nn/l2_normalize.py @@ -0,0 +1,29 @@ +# pylint: disable=invalid-name +"""TVM operator for l2 normalize""" +from __future__ import absolute_import +import tvm +from .. import cpp + +@tvm.target.generic_func +def l2_normalize(data, eps, axis=None): + """Perform L2 normalization on the input data + + For axis=None, y(i, j) = x(i, j) / sqrt(max(sum(x^2), eps)) + + Parameters + ---------- + data : tvm.Tensor + 4-D with NCHW or NHWC layout + + eps : float + epsilon value + + axis : list of int + axis over the normalization applied + + Returns + ------- + output : tvm.Tensor + 4-D output with same shape + """ + return cpp.nn.l2_normalize(data, eps, axis) diff --git a/topi/python/topi/nn/local_response_norm.py b/topi/python/topi/nn/local_response_norm.py index b44e02214acc..73eb41242513 100644 --- a/topi/python/topi/nn/local_response_norm.py +++ b/topi/python/topi/nn/local_response_norm.py @@ -2,8 +2,7 @@ """TVM operator for local response norm compute.""" from __future__ import absolute_import import tvm -import topi -from .pad import pad +from .. 
import cpp @tvm.target.generic_func def lrn(data, size, axis=1, alpha=0.0001, beta=0.75, bias=2): @@ -42,27 +41,4 @@ def lrn(data, size, axis=1, alpha=0.0001, beta=0.75, bias=2): output : tvm.Tensor 4-D output with same shape """ - assert len(data.shape) == 4, "only support 4-dim lrn" - assert (size % 2) == 1, "size should be odd number" - assert (axis == 1) or (axis == 3), "axis should 1 or 3 for NCHW and NHWC" - ##Add padding on left & right of size radius first - pad_after = pad_before = [0, 0, 0, 0] - pad_after[axis] = pad_before[axis] = (size//2) - pad_data = pad(data, pad_before, pad_after, name="pad_data") - - rxs = tvm.reduce_axis((0, size), name='rxs') - if axis == 1: - #NCHW layout - sqr_sum = tvm.compute(data.shape, lambda i, j, k, l: tvm.sum( - pad_data[i, j + rxs, k, l] * pad_data[i, j + rxs, k, l], - axis=rxs)) - elif axis == 3: - #NHWC layout - sqr_sum = tvm.compute(data.shape, lambda i, j, k, l: tvm.sum( - pad_data[i, j, k, l + rxs] * pad_data[i, j, k, l + rxs], - axis=rxs)) - - sqr_sum_up = tvm.compute(data.shape, lambda i, j, k, l: tvm.power( - (bias + (alpha * sqr_sum[i, j, k, l] / size)), beta)) - - return topi.broadcast_div(data, sqr_sum_up) + return cpp.nn.lrn(data, size, axis, alpha, beta, bias) diff --git a/topi/python/topi/rocm/nn.py b/topi/python/topi/rocm/nn.py index d9c529155f7b..5a9b2ad84db0 100644 --- a/topi/python/topi/rocm/nn.py +++ b/topi/python/topi/rocm/nn.py @@ -1,13 +1,18 @@ """scheduler for normalization functions on rocm backend""" from __future__ import absolute_import as _abs -import topi +import tvm from .. import generic +from .. import cpp @generic.schedule_lrn.register(["rocm", "gpu"]) def schedule_lrn(outs): - return topi.cuda.schedule_lrn(outs) + target = tvm.target.current_target(allow_none=False) + cpp_target = cpp.TEST_create_target(target.target_name) + return cpp.rocm.schedule_lrn(cpp_target, outs) -@generic.schedule_l2norm.register(["rocm", "gpu"]) -def schedule_l2norm(outs): - return topi.cuda.schedule_l2norm(outs) +@generic.schedule_l2_normalize.register(["rocm", "gpu"]) +def schedule_l2_normalize(outs): + target = tvm.target.current_target(allow_none=False) + cpp_target = cpp.TEST_create_target(target.target_name) + return cpp.rocm.schedule_l2_normalize(cpp_target, outs) diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py index 3731040e3a85..c91eea7958ea 100644 --- a/topi/python/topi/testing/__init__.py +++ b/topi/python/topi/testing/__init__.py @@ -16,3 +16,5 @@ from .reorg_python import reorg_python from .region_python import region_python from .shortcut_python import shortcut_python +from .lrn_python import lrn_python +from .l2_normalize_python import l2_normalize_python diff --git a/topi/python/topi/testing/l2_normalize_python.py b/topi/python/topi/testing/l2_normalize_python.py new file mode 100644 index 000000000000..98f1843233a7 --- /dev/null +++ b/topi/python/topi/testing/l2_normalize_python.py @@ -0,0 +1,27 @@ +# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals +"""L2 normalize in python""" +import numpy as np + +def l2_normalize_python(a_np, eps, axis=None): + """L2 normalize operator in NCHW layout. 
+ + Parameters + ---------- + a_np : numpy.ndarray + 4-D with shape [batch, in_channel, in_height, in_width] + + eps : float + epsilon constant value + axis : list of int + axis over the normalization applied + + Returns + ------- + l2_normalize_out : np.ndarray + 4-D with shape [batch, out_channel, out_height, out_width] + """ + dot_value = np.power(a_np, 2.0) + sqr_sum = np.sum(dot_value, axis, keepdims=True) + sqrt_sum = np.sqrt(np.maximum(np.broadcast_to(sqr_sum, a_np.shape), eps)) + l2_normalize_out = np.divide(a_np, sqrt_sum) + return l2_normalize_out diff --git a/topi/python/topi/testing/lrn_python.py b/topi/python/topi/testing/lrn_python.py new file mode 100644 index 000000000000..4e44e8bcb635 --- /dev/null +++ b/topi/python/topi/testing/lrn_python.py @@ -0,0 +1,53 @@ +# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals +"""LRN in python""" +from itertools import product +import numpy as np + +def lrn_python(a_np, size, axis, bias, alpha, beta): + """Local response normalization operator in NCHW layout. + + Parameters + ---------- + a_np : numpy.ndarray + 4-D with shape [batch, in_channel, in_height, in_width] + + size : int + normalization window size + + axis : int + input data layout channel axis + + bias : float + offset to avoid dividing by 0. constant value + + alpha : float + constant value + + beta : float + exponent constant value + + Returns + ------- + lrn_out : np.ndarray + 4-D with shape [batch, out_channel, out_height, out_width] + """ + radius = size // 2 + sqr_sum = np.zeros(shape=a_np.shape).astype(a_np.dtype) + for i, j, k, l in product(*[range(_axis) for _axis in a_np.shape]): + axis_size = a_np.shape[axis] + if axis == 1: + #NCHW layout + sum_start = j-radius if j-radius >= 0 else 0 + sum_end = j+radius+1 if j+radius+1 < axis_size else axis_size + sqr_sum[i, j, k, l] = sum(a_np[i, sum_start:sum_end, k, l] * \ + a_np[i, sum_start:sum_end, k, l]) + elif axis == 3: + #NHWC layout + sum_start = l-radius if l-radius >= 0 else 0 + sum_end = l+radius+1 if l+radius+1 < axis_size else axis_size + sqr_sum[i, j, k, l] = sum(a_np[i, j, k, sum_start:sum_end] * \ + a_np[i, j, k, sum_start:sum_end]) + + sqr_sum_up = np.power((bias + (alpha * sqr_sum /size)), beta) + lrn_out = np.divide(a_np, sqr_sum_up) + return lrn_out diff --git a/topi/src/topi.cc b/topi/src/topi.cc index 4169f5f563ad..9f2ecacd11a4 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -24,6 +24,8 @@ #include #include #include +#include +#include #include #include @@ -39,6 +41,7 @@ #include #include #include +#include #include #include @@ -46,6 +49,7 @@ #include #include +#include namespace topi { @@ -359,6 +363,20 @@ TVM_REGISTER_GLOBAL("topi.nn.log_softmax") *rv = nn::log_softmax(args[0]); }); +/* Ops from nn/l2_normalize.h */ +TVM_REGISTER_GLOBAL("topi.nn.l2_normalize") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = nn::l2_normalize(args[0], static_cast(args[1]), args[2]); + }); + +TVM_REGISTER_GLOBAL("topi.nn.lrn") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = nn::lrn(args[0], args[1], args[2], + static_cast(args[3]), + static_cast(args[4]), + static_cast(args[5])); + }); + TVM_REGISTER_GLOBAL("topi.vision.reorg") .set_body([](TVMArgs args, TVMRetValue *rv) { *rv = vision::reorg(args[0], args[1]); @@ -435,6 +453,17 @@ TVM_REGISTER_GLOBAL("topi.rocm.schedule_region") .set_body([](TVMArgs args, TVMRetValue *rv) { *rv = topi::rocm::schedule_region(args[0], args[1]); }); + +TVM_REGISTER_GLOBAL("topi.rocm.schedule_lrn") +.set_body([](TVMArgs args, TVMRetValue *rv) 
{ + *rv = topi::rocm::schedule_lrn(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.rocm.schedule_l2_normalize") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::rocm::schedule_l2_normalize(args[0], args[1]); + }); + /* CUDA schedules */ TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda") .set_body([](TVMArgs args, TVMRetValue *rv) { @@ -481,6 +510,16 @@ TVM_REGISTER_GLOBAL("topi.cuda.schedule_region") *rv = topi::cuda::schedule_region(args[0], args[1]); }); +TVM_REGISTER_GLOBAL("topi.cuda.schedule_lrn") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::cuda::schedule_lrn(args[0], args[1]); + }); + +TVM_REGISTER_GLOBAL("topi.cuda.schedule_l2_normalize") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::cuda::schedule_l2_normalize(args[0], args[1]); + }); + /*! \brief Builder function for instantiating schedules. */ using FTVMScheduleBuilder = std::function< tvm::Schedule(const tvm::Target& target, const tvm::Array& outs)>; diff --git a/topi/tests/python/test_topi_l2norm.py b/topi/tests/python/test_topi_l2norm.py index 182099ff9367..b27a1dc27e72 100644 --- a/topi/tests/python/test_topi_l2norm.py +++ b/topi/tests/python/test_topi_l2norm.py @@ -1,44 +1,18 @@ -"""Test code for L2 norm""" +"""Test code for L2 normalization""" import numpy as np import tvm import topi from topi.util import get_const_tuple +import topi.testing -def l2norm_instance_python(a_np, eps, axis=None): - """L2 norm operator in NCHW layout. +def verify_l2_normalize(ishape, eps, axis=None): - Parameters - ---------- - a_np : numpy.ndarray - 4-D with shape [batch, in_channel, in_height, in_width] - - eps : float - epsilon constant value - axis : list of int - axis over the normalization applied - - Returns - ------- - l2norm_out : np.ndarray - 4-D with shape [batch, out_channel, out_height, out_width] - """ - batch, axis1, axis2, axis3 = a_np.shape - sqr_sum = np.zeros(shape=(batch,)).astype(a_np.dtype) - sqrt_sum = np.zeros(shape=(batch,)).astype(a_np.dtype) - l2norm_out = np.zeros(shape=a_np.shape).astype(a_np.dtype) - dot_value = np.power(a_np, 2.0) - sqr_sum = np.sum(dot_value, axis, keepdims=True) - sqrt_sum = np.sqrt(np.maximum(np.broadcast_to(sqr_sum, a_np.shape), eps)) - return np.divide(a_np, sqrt_sum) - -def verify_l2norm(n, c, h, w, eps, axis=None): - - A = tvm.placeholder((n, c, h, w), name='A') - B = topi.nn.l2norm_instance(A, eps, axis) + A = tvm.placeholder(ishape, name='A') + B = topi.nn.l2_normalize(A, eps, axis) dtype = A.dtype - a_np = np.random.uniform(size=(n, c, h, w)).astype(dtype) - b_np = l2norm_instance_python(a_np, eps, axis) + a_np = np.random.uniform(size=ishape).astype(dtype) + b_np = topi.testing.l2_normalize_python(a_np, eps, axis) def check_device(device): ctx = tvm.context(device, 0) @@ -47,7 +21,10 @@ def check_device(device): return print("Running on target: %s" % device) with tvm.target.create(device): - s = topi.generic.schedule_l2norm(B) + if device == 'llvm': + s = topi.generic.schedule_l2_normalize([B]) + else: + s = topi.cuda.schedule_l2_normalize([B]) a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx) f = tvm.build(s, [A, B], device) @@ -57,14 +34,14 @@ def check_device(device): for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan']: check_device(device) -def test_l2norm(): - verify_l2norm(1, 3, 20, 20, 0.001) - verify_l2norm(1, 3, 20, 20, 0.001, 1) - verify_l2norm(1, 3, 20, 20, 0.001, (1, 2)) - verify_l2norm(1, 3, 20, 20, 0.001, (2, 3)) - verify_l2norm(1, 3, 20, 20, 0.001, (0, 3)) - 
verify_l2norm(1, 3, 20, 20, 0.001, (0, 2, 3)) +def test_l2_normalize(): + verify_l2_normalize((1, 3, 20, 20), 0.001) + verify_l2_normalize((1, 3, 20, 20), 0.001, (1,)) + verify_l2_normalize((1, 3, 20, 20), 0.001, (1, 2)) + verify_l2_normalize((1, 3, 20, 20), 0.001, (2, 3)) + verify_l2_normalize((1, 3, 20, 20), 0.001, (0, 3)) + verify_l2_normalize((1, 3, 20, 20), 0.001, (0, 2, 3)) if __name__ == "__main__": - test_l2norm() + test_l2_normalize() diff --git a/topi/tests/python/test_topi_lrn.py b/topi/tests/python/test_topi_lrn.py index 596e5747a6c5..7d62aefe5f55 100644 --- a/topi/tests/python/test_topi_lrn.py +++ b/topi/tests/python/test_topi_lrn.py @@ -3,63 +3,7 @@ import tvm import topi from topi.util import get_const_tuple - -def lrn_python(a_np, size, axis, bias, alpha, beta): - """Local response norm operator in NCHW layout. - - Parameters - ---------- - a_np : numpy.ndarray - 4-D with shape [batch, in_channel, in_height, in_width] - - size : int - normalisation window size - - axis : int - input data layout channel axis - - bias : float - offset to avoid dividing by 0. constant value - - alpha : float - contant valie - - beta : float - exponent constant value - - Returns - ------- - b_np : np.ndarray - 4-D with shape [batch, out_channel, out_height, out_width] - """ - axis0, axis1, axis2, axis3 = a_np.shape - radius = size // 2 - sqr_sum = np.zeros(shape=a_np.shape).astype(a_np.dtype) - sqr_sum_up = np.zeros(shape=a_np.shape).astype(a_np.dtype) - lrn_out = np.zeros(shape=a_np.shape).astype(a_np.dtype) - def sum_dot_values(i, j, k, l): - axis_size = a_np.shape[axis] - if (axis == 1): - #NCHW layout - sum_start = j-radius if j-radius >= 0 else 0 - sum_end = j+radius+1 if j+radius+1 < axis_size else axis_size - sqr_sum[i, j, k, l] = sum(a_np[i, sum_start:sum_end, k, l] * \ - a_np[i, sum_start:sum_end, k, l]) - elif (axis == 3): - #NHWC layout - sum_start = l-radius if l-radius >= 0 else 0 - sum_end = l+radius+1 if l+radius+1 < axis_size else axis_size - sqr_sum[i, j, k, l] = sum(a_np[i, j, k, sum_start:sum_end] * \ - a_np[i, j, k, sum_start:sum_end]) - - for i in range(axis0): - for j in range(axis1): - for k in range(axis2): - for l in range(axis3): - sum_dot_values(i, j, k, l) - - sqr_sum_up = np.power((bias + (alpha * sqr_sum /size)), beta) - return np.divide(a_np, sqr_sum_up) +import topi.testing def verify_lrn(shape, size, axis, bias, alpha, beta): A = tvm.placeholder(shape, name='A') @@ -67,16 +11,19 @@ def verify_lrn(shape, size, axis, bias, alpha, beta): dtype = A.dtype a_np = np.random.uniform(size=shape).astype(dtype) - b_np = lrn_python(a_np, size, axis, bias, alpha, beta) + b_np = topi.testing.lrn_python(a_np, size, axis, bias, alpha, beta) def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: + if not tvm.module.enabled(device): print("Skip because %s is not enabled" % device) return print("Running on target: %s" % device) with tvm.target.create(device): - s = topi.generic.schedule_lrn(B) + if device == 'llvm': + s = topi.generic.schedule_lrn([B]) + else: + s = topi.cuda.schedule_lrn([B]) + ctx = tvm.context(device, 0) a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx) f = tvm.build(s, [A, B], device) @@ -87,9 +34,9 @@ def check_device(device): check_device(device) def test_lrn(): - verify_lrn((1, 3, 5, 5), 3, 1, 1, 1, 0.5) - verify_lrn((1, 3, 5, 5), 3, 3, 1, 1, 0.5) - verify_lrn((1, 3, 20, 20), 3, 1, 2, 1, 0.75) + verify_lrn((1, 3, 5, 5), 3, 1, 1.0, 1.0, 0.5) + verify_lrn((1, 3, 5, 5), 3, 3, 1.0, 1.0, 0.5) 
+ verify_lrn((1, 3, 20, 20), 3, 1, 2.0, 1.0, 0.75) if __name__ == "__main__": test_lrn() diff --git a/topi/tests/python_cpp/test_topi_l2norm.py b/topi/tests/python_cpp/test_topi_l2norm.py new file mode 100644 index 000000000000..08799f76c5c3 --- /dev/null +++ b/topi/tests/python_cpp/test_topi_l2norm.py @@ -0,0 +1,48 @@ +"""Test code for l2 normalization""" +import numpy as np +import tvm +import topi +import logging +from topi.util import get_const_tuple +import topi.testing + +def verify_l2_normalize(shape, eps, axis=None): + '''Verify l2 normalization operator by comparing outputs from tvm and numpy implementation''' + A = tvm.placeholder(shape, name='A') + B = topi.cpp.nn.l2_normalize(A, eps, axis) + dtype = A.dtype + + a_np = np.random.uniform(size=shape).astype(dtype) + b_np = topi.testing.l2_normalize_python(a_np, eps, axis) + + def check_device(device): + if not tvm.module.enabled(device): + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + target = topi.cpp.TEST_create_target(device) + if device == "llvm": + s = topi.cpp.generic.default_schedule(target, [B], False) + else: + s = topi.cpp.cuda.schedule_l2_normalize(target, [B]) + ctx = tvm.context(device, 0) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) + func = tvm.build(s, [A, B], device, name="l2_normalize") + func(a, b) + np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + + for device in ['cuda', 'opencl', 'metal', 'rocm', 'llvm']: + check_device(device) + +def test_l2_normalize(): + verify_l2_normalize((1, 3, 20, 20), 0.001) + verify_l2_normalize((1, 3, 20, 20), 0.001, (1,)) + verify_l2_normalize((1, 3, 20, 20), 0.001, (1, 2)) + verify_l2_normalize((1, 3, 20, 20), 0.001, (2, 3)) + verify_l2_normalize((1, 3, 20, 20), 0.001, (0, 3)) + verify_l2_normalize((1, 3, 20, 20), 0.001, (0, 2, 3)) + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + test_l2_normalize() diff --git a/topi/tests/python_cpp/test_topi_lrn.py b/topi/tests/python_cpp/test_topi_lrn.py new file mode 100644 index 000000000000..d685643a9406 --- /dev/null +++ b/topi/tests/python_cpp/test_topi_lrn.py @@ -0,0 +1,44 @@ +"""Test code for LRN""" +import numpy as np +import tvm +import topi +import logging +from topi.util import get_const_tuple +import topi.testing + +def verify_lrn(shape, size, axis, bias, alpha, beta): + '''Verify Local response normalization operator by comparing outputs from tvm and numpy implementation''' + A = tvm.placeholder(shape, name='A') + B = topi.cpp.nn.lrn(A, size, axis, alpha, beta, bias) + dtype = A.dtype + + a_np = np.random.uniform(size=shape).astype(dtype) + b_np = topi.testing.lrn_python(a_np, size, axis, bias, alpha, beta) + def check_device(device): + if not tvm.module.enabled(device): + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + target = topi.cpp.TEST_create_target(device) + if device == "llvm": + s = topi.cpp.generic.default_schedule(target, [B], False) + else: + s = topi.cpp.cuda.schedule_lrn(target, [B]) + ctx = tvm.context(device, 0) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx) + f = tvm.build(s, [A, B], device) + f(a, b) + np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-1) + + for device in ['cuda', 'opencl', 'metal', 'rocm', 'llvm']: + check_device(device) + +def test_lrn(): + verify_lrn((1, 3, 5, 5), 3, 3, 1.0, 1.0, 0.5) + verify_lrn((1, 3, 5, 5), 3, 3, 1.0, 1.0, 0.5) 
+    verify_lrn((1, 3, 20, 20), 3, 1, 2.0, 1.0, 0.75)
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.DEBUG)
+    test_lrn()
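To close, a NumPy-only sketch of the two reference formulas the tests above rely on; it mirrors topi.testing.lrn_python and topi.testing.l2_normalize_python from this patch, and the helper names lrn_ref and l2_normalize_ref are illustrative, not part of the patch.

    # Reference math only; no TVM required.
    import numpy as np

    def l2_normalize_ref(x, eps, axis=None):
        # y = x / sqrt(max(sum(x^2, axis), eps)), broadcast over the reduced axes
        sqr_sum = np.sum(x * x, axis=axis, keepdims=True)
        return x / np.sqrt(np.maximum(sqr_sum, eps))

    def lrn_ref(x, size, axis, bias, alpha, beta):
        # Windowed sum of squares of `size` elements centred on the channel axis,
        # then y = x / (bias + alpha * sqr_sum / size) ** beta
        radius = size // 2
        n = x.shape[axis]
        sq = x * x
        sqr_sum = np.zeros_like(x)
        for c in range(n):
            lo, hi = max(0, c - radius), min(n, c + radius + 1)
            window = np.take(sq, list(range(lo, hi)), axis=axis).sum(axis=axis)
            idx = [slice(None)] * x.ndim
            idx[axis] = c
            sqr_sum[tuple(idx)] = window
        return x / np.power(bias + alpha * sqr_sum / size, beta)

    x = np.random.uniform(size=(1, 3, 20, 20)).astype("float32")
    y = lrn_ref(x, size=3, axis=1, bias=1.0, alpha=1.0, beta=0.5)
    z = l2_normalize_ref(x, eps=0.001, axis=(1,))
    assert y.shape == x.shape and z.shape == x.shape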