diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index c3c58e54517c..c870d17ecaa2 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -1371,6 +1371,14 @@ struct BatchToSpaceNDAttrs : public tvm::AttrsNode<BatchToSpaceNDAttrs> {
   }
 };  // struct BatchToSpaceNDAttrs
 
+struct EmbedAttrs : public tvm::AttrsNode<EmbedAttrs> {
+  TVM_DECLARE_ATTRS(EmbedAttrs, "relay.attrs.EmbedAttrs") {}
+};
+
+struct EmbedGradAttrs : public tvm::AttrsNode<EmbedGradAttrs> {
+  TVM_DECLARE_ATTRS(EmbedGradAttrs, "relay.attrs.EmbedGradAttrs") {}
+};
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_NN_H_
diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py
index f6ca3b179824..1de4e8d49df6 100644
--- a/python/tvm/autotvm/task/topi_integration.py
+++ b/python/tvm/autotvm/task/topi_integration.py
@@ -229,7 +229,11 @@ def wrapper(outs, *args, **kwargs):
             """wrapper function for topi schedule"""
             workload = get_workload(outs)
             if workload is None:
-                raise RuntimeError("Cannot find workload in attribute of this schedule")
+                raise RuntimeError(
+                    f"Cannot find workload for {task_name}. You may need to "
+                    "register a compute function for it with "
+                    f'`@tvm.autotvm.register_topi_compute("{task_name}")`'
+                )
             tgt = Target.current()
             cfg = DispatchContext.current.query(tgt, workload)
             return topi_schedule(cfg, outs, *args, **kwargs)
diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index 5836aebce393..df6664ef3ccc 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -827,7 +827,6 @@ def arange_grad(orig, grad):
     return [grad_start, grad_stop, grad_step]
 
 
-
 @register_gradient("gather_nd")
 def gather_nd_grad(orig, grad):
     """
@@ -866,3 +865,10 @@ def less_equal_grad(orig, grad):
     Returns the gradient of less_equal.
     """
     return [zeros_like(orig.args[0]), zeros_like(orig.args[1])]
+
+
+@register_gradient("nn.embed")
+def embed_grad(orig, grad):
+    """Returns the gradient of embed."""
+    table, indices = orig.args
+    return [_nn.embed_grad(table, indices, grad), zeros_like(indices)]
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 6ae86c0786e5..48faebdb5734 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -895,6 +895,27 @@ def compute_space_to_depth(attrs, inputs, out_dtype):
 reg.register_injective_schedule("nn.batch_to_space_nd")
 
 
+# embed
+@reg.register_compute("nn.embed")
+def compute_embed(attrs, inputs, out_type):
+    """Compute definition of embed"""
+    return [topi.nn.embed(inputs[0], inputs[1])]
+
+
+reg.register_injective_schedule("nn.embed")
+reg.register_pattern("nn.embed", OpPattern.INJECTIVE)
+
+
+@reg.register_compute("nn.embed_grad")
+def compute_embed_grad(attrs, inputs, out_type):
+    """Compute definition of embed_grad"""
+    return [topi.nn.embed_grad(inputs[0], inputs[1], inputs[2])]
+
+
+reg.register_strategy("nn.embed_grad", strategy.embed_grad_strategy)
+reg.register_pattern("nn.embed_grad", OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
 #####################
 #  Shape functions  #
 #####################
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 5135ac74de25..cef8320f54a4 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -3307,3 +3307,61 @@ def batch_to_space_nd(data, block_shape, crops):
     """
     return _make.batch_to_space_nd(data, block_shape, crops)
+
+
+def embed(table, indices):
+    """Look up indices in an embedding table.
+
+    The embedding lookup is defined as:
+
+    .. math::
+
+        O[i,j] = T[I[i],j]
+
+    where :math:`T` is the embedding table, and :math:`I` is the indices to
+    look up. This is a specialization of take with a two-dimensional input and
+    axis = 0.
+
+
+    Parameters
+    ----------
+    table : tvm.relay.Expr
+        M x N tensor of embedding locations.
+    indices : tvm.relay.Expr
+        Length K vector of indices to look up in `table`.
+
+    Returns
+    -------
+    Output : tvm.relay.Expr
+        K x N tensor corresponding to the rows of `table` indexed with `indices`.
+    """
+    return _make.embed(table, indices)
+
+
+def embed_grad(table, indices, grad):
+    """Gradient of :py:func:`embed`.
+
+    The gradient of an embedding lookup is defined as:
+
+    .. math::
+
+        O[I[i],j] = G[i, j]
+
+    where :math:`G` is the gradient and :math:`I` is the forward indices; duplicates accumulate.
+
+
+    Parameters
+    ----------
+    table : tvm.relay.Expr
+        M x N tensor of embedding locations.
+    indices : tvm.relay.Expr
+        Length K vector of indices to look up in `table`.
+    grad : tvm.relay.Expr
+        K x N tensor of the gradient to propagate.
+
+    Returns
+    -------
+    Output : tvm.relay.Expr
+        M x N tensor containing the propagated gradient.
+    """
+    return _make.embed_grad(table, indices, grad)
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index cb4688c4889e..51cfe0658c39 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -1017,3 +1017,15 @@ def cumsum_strategy_cuda(attrs, inputs, out_type, target):
         name="cumsum.cuda",
     )
     return strategy
+
+
+@embed_grad_strategy.register(["cuda", "gpu"])
+def embed_grad_strategy_gpu(attrs, inputs, out_type, target):
+    """gpu strategy for embed_grad"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_embed_grad(topi.cuda.embed_grad),
+        wrap_topi_schedule(topi.cuda.schedule_embed_grad),
+        name="embed_grad.cuda",
+    )
+    return strategy
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index f076176c5d8a..0ff6e457fc2f 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1432,3 +1432,25 @@ def cumsum_strategy(attrs, inputs, out_type, target):
         name="cumsum.generic",
     )
     return strategy
+
+
+def wrap_compute_embed_grad(topi_compute):
+    """wrap embed_grad"""
+
+    def _wrapped(attrs, inputs, out_type):
+        return [topi_compute(inputs[0], inputs[1], inputs[2])]
+
+    return _wrapped
+
+
+@override_native_generic_func("embed_grad_strategy")
+def embed_grad_strategy(attrs, inputs, out_type, target):
+    """embed gradient generic strategy"""
+    logger.warning("embed_grad is not optimized for this platform.")
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_embed_grad(topi.nn.embed_grad),
+        wrap_topi_schedule(topi.generic.schedule_embed_grad),
+        name="embed_grad.generic",
+    )
+    return strategy
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index 1f37a4f8e98c..493a81025375 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -569,3 +569,15 @@ def conv2d_winograd_without_weight_transfrom_strategy_cpu(attrs, inputs, out_typ
         "Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout)
     )
     return strategy
+
+
+@embed_grad_strategy.register("cpu")
+def embed_grad_strategy_cpu(attrs, inputs, out_type, target):
+    """x86 strategy for embed_grad"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_embed_grad(topi.nn.embed_grad),
+        wrap_topi_schedule(topi.x86.schedule_embed_grad),
name="embed_grad.x86", + ) + return strategy diff --git a/python/tvm/te/hybrid/calls.py b/python/tvm/te/hybrid/calls.py index 462066106a9d..31c4f46e9072 100644 --- a/python/tvm/te/hybrid/calls.py +++ b/python/tvm/te/hybrid/calls.py @@ -137,6 +137,11 @@ def _cast(func_id, args): return _expr.Cast(func_id, args[0]) +def cast(_, args): + _internal_assert(args.__len__() == 2, "cast requires two arguments: dtype, value") + return _expr.Cast(args[0], args[1]) + + float16 = float32 = float64 = _cast # pylint: disable=invalid-name int8 = int16 = int32 = int64 = _cast # pylint: disable=invalid-name uint8 = uint16 = uint32 = uint64 = _cast # pylint: disable=invalid-name diff --git a/python/tvm/te/hybrid/runtime.py b/python/tvm/te/hybrid/runtime.py index 615bd7e43a7d..9541e4692a78 100644 --- a/python/tvm/te/hybrid/runtime.py +++ b/python/tvm/te/hybrid/runtime.py @@ -119,6 +119,11 @@ def ninf(dtype): return numpy.iinfo(dtype).min +def cast(x, dtype): + """Convert `x` to `dtype`.""" + return getattr(numpy, dtype)(x) + + HYBRID_GLOBALS = { "unroll": range, "vectorize": range, @@ -150,8 +155,13 @@ def ninf(dtype): "float64": numpy.float64, "ceil_div": lambda a, b: (a + b - 1) // b, "max_num_threads": max_num_threads, +<<<<<<< HEAD "inf": inf, "ninf": inf, +||||||| parent of 57e650690... [TOPI] Add embed op and gradient. +======= + "cast": cast, +>>>>>>> 57e650690... [TOPI] Add embed op and gradient. } diff --git a/python/tvm/testing.py b/python/tvm/testing.py index d65ab23677b5..1591dc75b04e 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -745,4 +745,33 @@ def terminate_self(): sys.exit(-1) +def compare_numpy_tvm(inputs, output, target, ctx, compute, schedule): + """Compare a numpy inputs and output of a function to the results of the TVM version. + + Parameters + ---------- + inputs : Sequence[numpy.nd.array] + List of input numpy arrays to pass to the function. + output : numpy.nd.array + Verified correct function output. + target : tvm.target.Target + Target to run on. + ctx : tvm.TVMContext + Context to run on. + compute : callable + Topi compute function to test against. + schedule : callable + Topi scheduling function to test against. + """ + te_inputs = [tvm.te.placeholder(shape=i.shape, dtype=str(i.dtype)) for i in inputs] + te_out = tvm.nd.array(np.zeros(output.shape).astype(output.dtype), ctx=ctx) + with tvm.target.Target(target): + out = compute(*te_inputs) + s = schedule([out]) + func = tvm.build(s, te_inputs + [out]) + arys = [tvm.nd.array(x, ctx=ctx) for x in inputs] + func(*(arys + [te_out])) + assert_allclose(output, te_out.asnumpy(), atol=1e-4, rtol=1e-4) + + tvm._ffi._init_api("testing", __name__) diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py index bf3582c01d4f..bd86247ca167 100644 --- a/python/tvm/topi/cuda/__init__.py +++ b/python/tvm/topi/cuda/__init__.py @@ -40,7 +40,7 @@ from .injective import schedule_injective, schedule_elemwise, schedule_broadcast from .dense import * from .pooling import * -from .nn import schedule_lrn +from .nn import schedule_lrn, schedule_embed_grad from .batch_matmul import * from .batch_matmul_tensorcore import * from .vision import * diff --git a/python/tvm/topi/cuda/nn.py b/python/tvm/topi/cuda/nn.py index 0de377705531..c6820946689f 100644 --- a/python/tvm/topi/cuda/nn.py +++ b/python/tvm/topi/cuda/nn.py @@ -18,6 +18,7 @@ """scheduler functions for cuda backend""" from __future__ import absolute_import as _abs +from tvm import te from .. 
@@ -36,3 +40,43 @@ def schedule_lrn(outs):
     The computation schedule for the op.
     """
     return cpp.cuda.schedule_lrn(outs)
+
+
+def loads_per_thread(dtype):
+    """Number of elements each thread loads per 16-byte vectorized access."""
+    s = re.search("[0-9]+", dtype)
+    assert s is not None
+    byts = int(s.group()) // 8
+    return 16 // byts
+
+
+def schedule_embed_grad(outs):
+    """Schedule for embed_grad
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of embed_grad
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    s = te.create_schedule([outs[0].op])
+    # this should be autotuned, but we can't with hybrid script
+    vec_size = loads_per_thread(outs[0].dtype)
+    warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size)
+    num_warps = 4
+    out = s.outputs[0].output(0)
+    i, j = s[out].op.axis
+    jo, ji = s[out].split(j, factor=vec_size)
+    s[out].vectorize(ji)
+    joo, joi = s[out].split(jo, factor=warp_size)
+    s[out].bind(joi, te.thread_axis("threadIdx.x"))
+    _, jooi = s[out].split(joo, factor=num_warps)
+    s[out].bind(jooi, te.thread_axis("threadIdx.y"))
+    s[out].bind(i, te.thread_axis("blockIdx.x"))
+
+    return s
diff --git a/python/tvm/topi/generic/nn.py b/python/tvm/topi/generic/nn.py
index 60ccd0d36abf..167473aac62a 100644
--- a/python/tvm/topi/generic/nn.py
+++ b/python/tvm/topi/generic/nn.py
@@ -762,3 +762,20 @@ def schedule_correlation_nchw(outs):
         The computation schedule for the op.
     """
     return _default_schedule(outs, False)
+
+
+def schedule_embed_grad(outs):
+    """Schedule for embed_grad
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of embed_grad
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
diff --git a/python/tvm/topi/nn/__init__.py b/python/tvm/topi/nn/__init__.py
index 2ebbd1d67bd1..1d1a5c25105c 100644
--- a/python/tvm/topi/nn/__init__.py
+++ b/python/tvm/topi/nn/__init__.py
@@ -26,6 +26,7 @@
 from .deformable_conv2d import *
 from .depthwise_conv2d import *
 from .elemwise import *
+from .embed import *
 from .dilate import *
 from .flatten import *
 from .dense import *
diff --git a/python/tvm/topi/nn/embed.py b/python/tvm/topi/nn/embed.py
new file mode 100644
index 000000000000..1c7a01db07fc
--- /dev/null
+++ b/python/tvm/topi/nn/embed.py
@@ -0,0 +1,41 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+ +"""Embedding operators""" + +from tvm import te + + +@te.hybrid.script +def embed(table, indices): + out = output_tensor((indices.shape[0], table.shape[1]), table.dtype) + for i in range(indices.shape[0]): + for j in range(table.shape[1]): + out[i, j] = table[indices[i], j] + return out + + +@te.hybrid.script +def embed_grad(table, indices, grad_in): + grad_out = output_tensor(table.shape, table.dtype) + for i in range(table.shape[0]): + for j in range(table.shape[1]): + grad_out[i, j] = cast(table.dtype, 0.0) + for i in range(indices.shape[0]): + for j in range(table.shape[1]): + grad_out[indices[i], j] += grad_in[i, j] + return grad_out diff --git a/python/tvm/topi/x86/nn.py b/python/tvm/topi/x86/nn.py index 0994700fe98c..9095e648180d 100644 --- a/python/tvm/topi/x86/nn.py +++ b/python/tvm/topi/x86/nn.py @@ -69,3 +69,33 @@ def schedule_softmax(outs): s[exp].compute_at(s[softmax], fused_outer_axes) return s + + +def schedule_embed_grad(outs): + """Schedule for embed_grad + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of embed_grad + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + s = te.create_schedule([outs[0].op]) + + vec_size = 8 # should autotune this, but we can't with hybrid script + out = s.outputs[0].output(0) + zi, zj, i, j = s[out].op.axis + zjo, zji = s[out].split(zj, factor=vec_size) + s[out].vectorize(zji) + s[out].parallel(zjo) + s[out].reorder(zjo, zi, zji) + jo, ji = s[out].split(j, factor=vec_size) + s[out].vectorize(ji) + s[out].parallel(i) + + return s diff --git a/src/relay/op/nn/embed.cc b/src/relay/op/nn/embed.cc new file mode 100644 index 000000000000..1da6b744c3a2 --- /dev/null +++ b/src/relay/op/nn/embed.cc @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file embed.cc + * \brief Property def of nn.embed operator. 
+ */ + +#include +#include +#include + +#include + +#include "../../transforms/infer_layout_utils.h" + +namespace tvm { +namespace relay { + +TVM_REGISTER_NODE_TYPE(EmbedAttrs); + +bool EmbedRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 3) << "Embed shape relation takes three arguments: the embedding table, " + "the indices, and the output"; + const auto* table = types[0].as(); + const auto* indices = types[1].as(); + ICHECK_EQ(table->shape.size(), 2) << "Embed table must be a dimension 2 tensor."; + ICHECK_EQ(indices->shape.size(), 1) << "Embed indices must be a one dimensional vector."; + ICHECK(indices->dtype.is_int() || indices->dtype.is_uint()) << "Embed indices must be integers."; + + reporter->Assign(types[2], TensorType({indices->shape[0], table->shape[1]}, table->dtype)); + return true; +} + +Expr MakeEmbed(Expr table, Expr indices) { + auto attrs = make_object(); + static const Op& op = Op::Get("nn.embed"); + return Call(op, {table, indices}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.nn._make.embed").set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeEmbed, args, rv); +}); + +RELAY_REGISTER_OP("nn.embed") + .describe(R"code(Lookup of indices in an embedding table. + +- **table**: M x N tensor +- **indices**: K long tensor of indices into `table` +- **out**: K x N tensor + +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(2) + .add_argument("table", "2D Tensor", "Embedding table.") + .add_argument("indices", "1D Tensor", "Indices to lookup.") + .set_support_level(1) + .add_type_rel("Embed", EmbedRel); + +TVM_REGISTER_NODE_TYPE(EmbedGradAttrs); + +bool EmbedGradRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 4) << "EmbedGrad shape relation takes four arguments: the embedding " + "table, the indices, the gradient, and the output"; + const auto* table = types[0].as(); + const auto* indices = types[1].as(); + const auto* grad = types[2].as(); + ICHECK_EQ(table->shape.size(), 2) << "EmbedGrad table must be a dimension 2 tensor."; + ICHECK_EQ(indices->shape.size(), 1) << "EmbedGrad indices must be a one dimensional vector."; + ICHECK_EQ(grad->shape.size(), 2) << "EmbedGrad grad must be a dimension 2 tensor."; + ICHECK(indices->dtype.is_int() || indices->dtype.is_uint()) + << "EmbedGrad indices must be integers."; + reporter->AssertEQ(table->shape[0], grad->shape[0]); + reporter->AssertEQ(table->shape[1], grad->shape[1]); + + reporter->Assign(types[3], TensorType(table->shape, table->dtype)); + return true; +} + +Expr MakeEmbedGrad(Expr table, Expr indices, Expr grad) { + auto attrs = make_object(); + static const Op& op = Op::Get("nn.embed_grad"); + return Call(op, {table, indices, grad}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.nn._make.embed_grad") + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeEmbedGrad, args, rv); + }); + +RELAY_REGISTER_OP("nn.embed_grad") + .describe(R"code(Gradient of Embed + +- **table**: M x N tensor +- **indices**: K long tensor of indices into `table` +- **grad**: K x N tensor of the gradient +- **out**: M x N tensor + +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(3) + .add_argument("table", "2D Tensor", "EmbedGradding table.") + .add_argument("indices", "1D Tensor", "Indices to lookup.") + .add_argument("grad", "2D Tensor", "Gradient.") + .set_support_level(1) + 
.add_type_rel("EmbedGrad", EmbedGradRel); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_op_grad_level1.py b/tests/python/relay/test_op_grad_level1.py index 0ac604c6bca1..0f57e6dde6e1 100644 --- a/tests/python/relay/test_op_grad_level1.py +++ b/tests/python/relay/test_op_grad_level1.py @@ -160,5 +160,15 @@ def test_concatenate_grad(): check_grad(fwd_func) +def test_embed_grad(): + table = relay.var("table", shape=(6, 3), dtype="float64") + indices = relay.var("indices", shape=(4,), dtype="int64") + table_nd = np.reshape(np.arange(18), (6, 3)).astype("float64") + indices_nd = np.array([0, 0, 3, 2]).astype("int64") + fwd_func = relay.Function([table, indices], relay.nn.embed(table, indices)) + # Can't test against indices because the function is nonsmooth with respect to them. + check_grad(fwd_func, inputs=[table_nd, indices_nd], test_inputs=[table_nd]) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/topi/python/test_topi_embed.py b/tests/python/topi/python/test_topi_embed.py new file mode 100644 index 000000000000..078f7f63afb8 --- /dev/null +++ b/tests/python/topi/python/test_topi_embed.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +import tvm.topi.testing +import numpy as np +from tvm import topi + + +@tvm.testing.requires_llvm +def test_embed(): + M = 64 + N = 128 + K = 10 + table = np.reshape(np.arange(N * M), (M, N)).astype("float64") + indices = np.random.randint(M, size=(K,)) + out = table[indices, :] + tvm.testing.compare_numpy_tvm( + [table, indices], + out, + "llvm", + tvm.context("cpu"), + topi.nn.embed, + tvm.topi.testing.get_injective_schedule("llvm"), + ) + + +@tvm.testing.parametrize_targets +def test_embed_grad(ctx, target): + M = 30 + N = 50 + K = 10 + table = np.reshape(np.arange(N * M), (M, N)).astype("float64") + indices = np.random.randint(M, size=(K,)) + indices[0] = indices[-1] # ensure we have duplicate indices + grad = np.reshape(np.arange(K * N), (K, N)).astype(table.dtype) + grad_out = np.zeros((M, N)).astype(table.dtype) + + for i, ind in enumerate(indices): + grad_out[ind, :] += grad[i, :] + + implementations = { + "cpu": (topi.nn.embed_grad, topi.x86.schedule_embed_grad), + "gpu": (topi.nn.embed_grad, topi.cuda.schedule_embed_grad), + } + fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations) + tvm.testing.compare_numpy_tvm( + [table, indices, grad], grad_out, target, ctx, fcompute, fschedule + ) + + +if __name__ == "__main__": + test_embed() + test_embed_grad(tvm.context("cpu"), "llvm")