From 4f7525c945fcddb6928e93261386688c760f0d64 Mon Sep 17 00:00:00 2001
From: cailun01 <1354213521@qq.com>
Date: Fri, 18 Jun 2021 09:34:25 +0800
Subject: [PATCH 1/7] [TensorFlow][Frontend] Adding InvertPermutation Op

Computes the inverse permutation of a tensor. This op is used by
Mask R-CNN and other object detection models.
---
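Notes (text between the three dashes and the diffstat is dropped by git am):
a minimal sketch of the intended end-to-end behavior, assuming this patch is
applied; it mirrors the docstring example below and runs on the llvm target.

    import numpy as np
    import tvm
    from tvm import relay

    # InvertPermutation scatters each index i to position x[i]: y[x[i]] = i.
    x = relay.var("x", shape=(5,), dtype="int32")
    mod = tvm.IRModule.from_expr(relay.Function([x], relay.invert_permutation(x)))
    run = relay.create_executor("graph", mod=mod, target="llvm").evaluate()
    out = run(np.array([3, 4, 0, 2, 1], dtype="int32"))
    np.testing.assert_array_equal(out.asnumpy(), [2, 4, 3, 0, 1])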
 python/tvm/relay/frontend/tensorflow_ops.py   |   1 +
 python/tvm/relay/op/_transform.py             |   4 +
 python/tvm/relay/op/strategy/cuda.py          |  12 ++
 python/tvm/relay/op/strategy/generic.py       |  22 ++++
 python/tvm/relay/op/transform.py              |  28 ++++
 python/tvm/topi/cuda/transform.py             |  69 +++++++++-
 python/tvm/topi/transform.py                  |  31 +++++
 src/relay/op/tensor/transform.cc              |  19 +++
 .../frontend/tensorflow/test_forward.py       | 122 ++++++++++--------
 9 files changed, 255 insertions(+), 53 deletions(-)

diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py
index 3c4a9b69ea6e..be15f83faf0f 100644
--- a/python/tvm/relay/frontend/tensorflow_ops.py
+++ b/python/tvm/relay/frontend/tensorflow_ops.py
@@ -2886,6 +2886,7 @@ def _impl(inputs, attr, params, mod):
     "GreaterEqual": _broadcast("greater_equal"),
     "Identity": _identity(),
     "IdentityN": _identityn(),
+    "InvertPermutation": AttrCvt("invert_permutation"),
     "IsFinite": AttrCvt("isfinite"),
     "IsInf": AttrCvt("isinf"),
     "IsNan": AttrCvt("isnan"),
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index f87b5ed0b8ef..bee188f19364 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -178,6 +178,10 @@ def compute_unique(attrs, inputs, output_type):
 
 _reg.register_strategy("unique", strategy.unique_strategy)
 
+# invert_permutation
+_reg.register_strategy("invert_permutation", strategy.invert_permutation_strategy)
+_reg.register_shape_func("invert_permutation", False, elemwise_shape_func)
+
 #####################
 #  Shape functions  #
 #####################
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index b4db412700a7..6418f1f96b3b 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -1135,3 +1135,15 @@ def schedule_transpose_cuda(attrs, outs, target):
     ):
         return topi.cuda.schedule_transpose(outs)
     return schedule_injective(attrs, outs, target)
+
+
+@invert_permutation_strategy.register(["cuda", "gpu"])
+def invert_permutation_strategy_cuda(attrs, inputs, out_type, target):
+    """invert_permutation cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_invert_permutation(topi.cuda.invert_permutation),
+        wrap_topi_schedule(topi.cuda.vision._default_schedule),
+        name="invert_permutation.cuda",
+    )
+    return strategy
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index d56820e409aa..35e5177458a5 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1607,3 +1607,25 @@ def schedule_transpose(attrs, outs, target):
     """schedule transpose"""
     with target:
         return schedule_injective(attrs, outs, target)
+
+
+# invert_permutation
+def wrap_compute_invert_permutation(topi_compute):
+    """wrap invert_permutation topi compute"""
+
+    def _compute_invert_permutation(attrs, inputs, out_type):
+        return [topi_compute(inputs[0])]
+
+    return _compute_invert_permutation
+
+
+@override_native_generic_func("invert_permutation_strategy")
+def invert_permutation_strategy(attrs, inputs, out_type, target):
+    """invert_permutation generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_invert_permutation(topi.invert_permutation),
+        wrap_topi_schedule(topi.generic.schedule_injective),
+        name="invert_permutation.generic",
+    )
+    return strategy
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 440d2fae042f..9cb50ed6548a 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -1720,3 +1720,31 @@ def unique(data, is_sorted=True, return_counts=False):
     if return_counts:
         return TupleWrapper(_make.unique(data, is_sorted, return_counts), 5)
     return TupleWrapper(_make.unique(data, is_sorted, return_counts), 4)
+
+
+def invert_permutation(data):
+    """Computes the inverse permutation of data.
+    This operation computes the inverse of an index permutation.
+    It takes a 1-D integer tensor x, which represents the indices of a zero-based
+    array, and swaps each value with its index position.
+
+    For an output tensor y and an input tensor x, this operation computes the following:
+    y[x[i]] = i for i in [0, 1, ..., len(x) - 1]
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The source data to be invert permuted.
+
+    Returns
+    -------
+    ret : relay.Expr
+        Invert permuted data. Has the same type as data.
+
+    Examples
+    --------
+    .. code-block:: python
+        data = [3, 4, 0, 2, 1]
+        relay.invert_permutation(data) = [2, 4, 3, 0, 1]
+    """
+    return _make.invert_permutation(data)
diff --git a/python/tvm/topi/cuda/transform.py b/python/tvm/topi/cuda/transform.py
index 89caf94bbbc1..28a37f1e5f4c 100644
--- a/python/tvm/topi/cuda/transform.py
+++ b/python/tvm/topi/cuda/transform.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """CUDA implementations of transforms"""
-
+import tvm
 from ... import te
 from ...target import Target
 from ..utils import traverse_inline
@@ -65,3 +65,70 @@ def _callback(op):
             s[c].bind(ao, thread_y)
 
     traverse_inline(s, out.op, _callback)
+
+
+def _invert_permutation_ir(data, out):
+    """Low level IR to compute invert_permutation.
+
+    Parameters
+    ----------
+    data : Buffer
+        Input data. 1-D Buffer with shape [elem_num].
+
+    out : Buffer
+        1-D Buffer for the invert permutation result, with the same shape as data.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    elem_num = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+    data = ib.buffer_ptr(data)
+    out = ib.buffer_ptr(out)
+
+    max_threads = int(Target.current(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+    nthread_bx = elem_num // max_threads + 1
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    ib.scope_attr(bx, "thread_extent", nthread_bx)
+    tid = bx * max_threads + tx
+
+    with ib.if_scope(tid < elem_num):
+        r_ind = data[tid]
+        out[r_ind] = tid
+    return ib.get()
+
+
+def invert_permutation(data):
+    """Compute definition of invert_permutation.
+    For an output tensor y and an input tensor x, this operation computes the following:
+
+    y[x[i]] = i for i in [0, 1, ..., len(x) - 1]
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        1-D tensor
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+    """
+    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+    out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf", data_alignment=8)
+
+    out = te.extern(
+        [data.shape],
+        [data],
+        lambda ins, outs: _invert_permutation_ir(ins[0], outs[0]),
+        in_buffers=[data_buf,],
+        out_buffers=[out_buf,],
+        name="invert_permutation",
+        tag="invert_permutation_gpu",
+    )
+    return out
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index b4d0167be2b1..8294d1e53ac0 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -20,6 +20,7 @@
 import tvm
 from tvm import te
 from tvm import topi
+from tvm.te import hybrid
 from . import cpp
 from . import tag
 from .utils import within_index, make_idx, const_vector
@@ -941,3 +942,33 @@ def adv_index(data, indices):
         Output tensor
     """
     return cpp.adv_index(data, indices)
+
+
+@hybrid.script
+def invert_permutation(data):
+    """Computes the inverse permutation of data.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        Input data
+
+    Returns
+    -------
+    result : tvm.te.Tensor
+        Output tensor
+
+    Examples
+    --------
+    .. code-block:: python
+
+        data = [3, 4, 0, 2, 1]
+
+        topi.invert_permutation(data) = [2, 4, 3, 0, 1]
+    """
+    result = output_tensor(data.shape, data.dtype)
+    nums = data.shape[0]
+    for ind in range(nums):
+        r_ind = data[ind]
+        result[r_ind] = ind
+    return result
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 7d40bf22bcee..643353d7da6c 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3976,5 +3976,24 @@ RELAY_REGISTER_OP("unique")
     .add_type_rel("unique", UniqueRel)
     .set_support_level(3)
     .set_attr<TOpPattern>("TOpPattern", kOpaque);
+
+// invert_permutation
+Expr MakeInvertPermutation(Expr data) {
+  static const Op& op = Op::Get("invert_permutation");
+  return Call(op, {data}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.invert_permutation").set_body_typed(MakeInvertPermutation);
+
+RELAY_REGISTER_OP("invert_permutation")
+    .describe(
+        R"doc(Computes the inverse permutation of a tensor.)doc" TVM_ADD_FILELINE)
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_type_rel("Identity", IdentityRel)
+    .set_support_level(1)
+    .set_attr<TOpPattern>("TOpPattern", kInjective)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false);
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 57497d04706a..c1f303f18e07 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -1717,58 +1717,58 @@ def test_forward_variable():
     _test_variable(np.random.uniform(size=(32, 100)).astype("float32"))
 
 
-@tvm.testing.parametrize_targets("llvm", "cuda")
-def test_read_variable_op(target, dev):
-    """Read Variable op test"""
-
-    tf.reset_default_graph()
-    data = np.random.uniform(size=(32, 100)).astype("float32")
-    input_tensor = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
-
-    size = input_tensor.shape.dims[1]
-    var_data = np.random.uniform(-5, 5, size=[size, size]).astype(np.float32)
-    input_var = tf.Variable(var_data, name="var1", use_resource=True)
-    math_ops.matmul(input_tensor, input_var)
-
-    out_name = ["MatMul:0"]
= ["MatMul:0"] - out_node = ["MatMul"] - in_name = ["Placeholder:0"] - in_node = ["Placeholder"] - in_data = [data] - - with tf.Session() as sess: - sess.run(variables.global_variables_initializer()) - - final_graph_def = sess.graph.as_graph_def(add_shapes=True) - tf_output = run_tf_graph(sess, in_data, in_name, out_name) - - shape_dict = {e: i.shape for e, i in zip(in_name, in_data)} - with pytest.raises(Exception) as execinfo: - mod, params = relay.frontend.from_tensorflow( - final_graph_def, layout=None, shape=shape_dict, outputs=None - ) - - assert execinfo.value.args[0].startswith("Graph is not frozen. Provide a frozen graph") - - # Now convert the variables to constant and run inference on the converted graph - final_graph_def = tf.graph_util.convert_variables_to_constants( - sess, - sess.graph.as_graph_def(add_shapes=True), - out_node, - ) - - tvm_output = run_tvm_graph( - final_graph_def, - in_data, - in_node, - target=target, - out_names=out_name, - num_output=len(out_name), - ) - for i in range(len(tf_output)): - tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-4, rtol=1e-5) - - sess.close() +# @tvm.testing.parametrize_targets("llvm", "cuda") +# def test_read_variable_op(target, dev): +# """Read Variable op test""" + +# tf.reset_default_graph() +# data = np.random.uniform(size=(32, 100)).astype("float32") +# input_tensor = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + +# size = input_tensor.shape.dims[1] +# var_data = np.random.uniform(-5, 5, size=[size, size]).astype(np.float32) +# input_var = tf.Variable(var_data, name="var1", use_resource=True) +# math_ops.matmul(input_tensor, input_var) + +# out_name = ["MatMul:0"] +# out_node = ["MatMul"] +# in_name = ["Placeholder:0"] +# in_node = ["Placeholder"] +# in_data = [data] + +# with tf.Session() as sess: +# sess.run(variables.global_variables_initializer()) + +# final_graph_def = sess.graph.as_graph_def(add_shapes=True) +# tf_output = run_tf_graph(sess, in_data, in_name, out_name) + +# shape_dict = {e: i.shape for e, i in zip(in_name, in_data)} +# with pytest.raises(Exception) as execinfo: +# mod, params = relay.frontend.from_tensorflow( +# final_graph_def, layout=None, shape=shape_dict, outputs=None +# ) + +# assert execinfo.value.args[0].startswith("Graph is not frozen. 
Provide a frozen graph") + +# # Now convert the variables to constant and run inference on the converted graph +# final_graph_def = tf.graph_util.convert_variables_to_constants( +# sess, +# sess.graph.as_graph_def(add_shapes=True), +# out_node, +# ) + +# tvm_output = run_tvm_graph( +# final_graph_def, +# in_data, +# in_node, +# target=target, +# out_names=out_name, +# num_output=len(out_name), +# ) +# for i in range(len(tf_output)): +# tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-4, rtol=1e-5) + +# sess.close() ####################################################################### @@ -5569,5 +5569,23 @@ def @main(%A: Tensor[(4, 176, 8, 8), float32]) { tvm.ir.assert_structural_equal(mod["main"].body, mod_golden["main"].body, map_free_vars=True) +####################################################################### +# invert_permutation +# -------------------- + + +def test_invert_permutation(): + """test InvertPermutation""" + tf.reset_default_graph() + + input_shape = [6] + x = np.array([3, 4, 0, 2, 1, 5]).astype("int32") + with tf.Graph().as_default(): + in_data = tf.placeholder(shape=input_shape, dtype="int32") + tf.invert_permutation(in_data) + out_name = "InvertPermutation:0" + compare_tf_with_tvm(x, "Placeholder:0", out_name, no_gpu=False) + + if __name__ == "__main__": pytest.main([__file__]) From d83178b7e02fdfc2d1a2b247329114e2b4122e0e Mon Sep 17 00:00:00 2001 From: cailun01 <1354213521@qq.com> Date: Fri, 18 Jun 2021 17:42:26 +0800 Subject: [PATCH 2/7] uncomment test_read_variable_op --- python/tvm/topi/transform.py | 2 - .../frontend/tensorflow/test_forward.py | 121 ++++++++---------- 2 files changed, 52 insertions(+), 71 deletions(-) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 8294d1e53ac0..45756eadbcdb 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -961,9 +961,7 @@ def invert_permutation(data): Examples -------- .. 
-
         data = [3, 4, 0, 2, 1]
-
         topi.invert_permutation(data) = [2, 4, 3, 0, 1]
     """
     result = output_tensor(data.shape, data.dtype)
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index c1f303f18e07..c66118d6d731 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -1717,58 +1717,58 @@ def test_forward_variable():
     _test_variable(np.random.uniform(size=(32, 100)).astype("float32"))
 
 
-# @tvm.testing.parametrize_targets("llvm", "cuda")
-# def test_read_variable_op(target, dev):
-#     """Read Variable op test"""
-
-#     tf.reset_default_graph()
-#     data = np.random.uniform(size=(32, 100)).astype("float32")
-#     input_tensor = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
-
-#     size = input_tensor.shape.dims[1]
-#     var_data = np.random.uniform(-5, 5, size=[size, size]).astype(np.float32)
-#     input_var = tf.Variable(var_data, name="var1", use_resource=True)
-#     math_ops.matmul(input_tensor, input_var)
-
-#     out_name = ["MatMul:0"]
-#     out_node = ["MatMul"]
-#     in_name = ["Placeholder:0"]
-#     in_node = ["Placeholder"]
-#     in_data = [data]
-
-#     with tf.Session() as sess:
-#         sess.run(variables.global_variables_initializer())
-
-#         final_graph_def = sess.graph.as_graph_def(add_shapes=True)
-#         tf_output = run_tf_graph(sess, in_data, in_name, out_name)
-
-#         shape_dict = {e: i.shape for e, i in zip(in_name, in_data)}
-#         with pytest.raises(Exception) as execinfo:
-#             mod, params = relay.frontend.from_tensorflow(
-#                 final_graph_def, layout=None, shape=shape_dict, outputs=None
-#             )
-
-#         assert execinfo.value.args[0].startswith("Graph is not frozen. Provide a frozen graph")
-
-#         # Now convert the variables to constant and run inference on the converted graph
-#         final_graph_def = tf.graph_util.convert_variables_to_constants(
-#             sess,
-#             sess.graph.as_graph_def(add_shapes=True),
-#             out_node,
-#         )
-
-#         tvm_output = run_tvm_graph(
-#             final_graph_def,
-#             in_data,
-#             in_node,
-#             target=target,
-#             out_names=out_name,
-#             num_output=len(out_name),
-#         )
-#         for i in range(len(tf_output)):
-#             tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-4, rtol=1e-5)
-
-#         sess.close()
+@tvm.testing.parametrize_targets("llvm", "cuda")
+def test_read_variable_op(target, dev):
+    """Read Variable op test"""
+
+    tf.reset_default_graph()
+    data = np.random.uniform(size=(32, 100)).astype("float32")
+    input_tensor = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+
+    size = input_tensor.shape.dims[1]
+    var_data = np.random.uniform(-5, 5, size=[size, size]).astype(np.float32)
+    input_var = tf.Variable(var_data, name="var1", use_resource=True)
+    math_ops.matmul(input_tensor, input_var)
+
+    out_name = ["MatMul:0"]
+    out_node = ["MatMul"]
+    in_name = ["Placeholder:0"]
+    in_node = ["Placeholder"]
+    in_data = [data]
+
+    with tf.Session() as sess:
+        sess.run(variables.global_variables_initializer())
+
+        final_graph_def = sess.graph.as_graph_def(add_shapes=True)
+        tf_output = run_tf_graph(sess, in_data, in_name, out_name)
+
+        shape_dict = {e: i.shape for e, i in zip(in_name, in_data)}
+        with pytest.raises(Exception) as execinfo:
+            mod, params = relay.frontend.from_tensorflow(
+                final_graph_def, layout=None, shape=shape_dict, outputs=None
+            )
+
+        assert execinfo.value.args[0].startswith("Graph is not frozen. Provide a frozen graph")
Provide a frozen graph") + + # Now convert the variables to constant and run inference on the converted graph + final_graph_def = tf.graph_util.convert_variables_to_constants( + sess, + sess.graph.as_graph_def(add_shapes=True), + out_node, + ) + + tvm_output = run_tvm_graph( + final_graph_def, + in_data, + in_node, + target=target, + out_names=out_name, + num_output=len(out_name), + ) + for i in range(len(tf_output)): + tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-4, rtol=1e-5) + + sess.close() ####################################################################### @@ -1843,9 +1843,6 @@ def test_forward_batch_matmul(): _test_batch_matmul((1, 2, 3, 4, 5, 6), (1, 2, 3, 4, 6, 5), "float32", True, True) _test_batch_matmul((3, 4, 5, 6), (3, 4, 5, 6), "int32", True, False) _test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), "float32", False, True) - _test_batch_matmul((1, 8, 64, 2), (2, 1), "float32", False, False) - _test_batch_matmul((1, 8, 8, 64), (64, 1), "float32", False, False) - _test_batch_matmul((1, 8, 64), (64, 1), "float32", False, False) @tvm.testing.requires_cuda @@ -1873,20 +1870,6 @@ def test_forward_batch_matmul_dynamic(): (2, 3, 4, 6, 5), "float32", ) - _test_batch_matmul_dynamic( - (None, None, None, 5, 6), - (6, None), - (2, 3, 4, 5, 6), - (6, 1), - "float32", - ) - _test_batch_matmul_dynamic( - (None, 5, 6), - (6, None), - (24, 5, 6), - (6, 1), - "float32", - ) ####################################################################### From d93b2c96d5001aca64a16c120553214616b5a62b Mon Sep 17 00:00:00 2001 From: cailun01 <1354213521@qq.com> Date: Fri, 18 Jun 2021 18:49:49 +0800 Subject: [PATCH 3/7] restore several tests --- .../python/frontend/tensorflow/test_forward.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index c66118d6d731..0ef3317525b3 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1843,6 +1843,9 @@ def test_forward_batch_matmul(): _test_batch_matmul((1, 2, 3, 4, 5, 6), (1, 2, 3, 4, 6, 5), "float32", True, True) _test_batch_matmul((3, 4, 5, 6), (3, 4, 5, 6), "int32", True, False) _test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), "float32", False, True) + _test_batch_matmul((1, 8, 64, 2), (2, 1), "float32", False, False) + _test_batch_matmul((1, 8, 8, 64), (64, 1), "float32", False, False) + _test_batch_matmul((1, 8, 64), (64, 1), "float32", False, False) @tvm.testing.requires_cuda @@ -1870,6 +1873,20 @@ def test_forward_batch_matmul_dynamic(): (2, 3, 4, 6, 5), "float32", ) + _test_batch_matmul_dynamic( + (None, None, None, 5, 6), + (6, None), + (2, 3, 4, 5, 6), + (6, 1), + "float32", + ) + _test_batch_matmul_dynamic( + (None, 5, 6), + (6, None), + (24, 5, 6), + (6, 1), + "float32", + ) ####################################################################### From 15a4564745a7136b98ec6ef76a7aeff0d3c2fcfe Mon Sep 17 00:00:00 2001 From: cailun01 <1354213521@qq.com> Date: Sun, 20 Jun 2021 14:18:37 +0800 Subject: [PATCH 4/7] fix lint error --- src/relay/op/tensor/transform.cc | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 643353d7da6c..b3a34e83b70a 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3290,8 +3290,6 @@ which must just be not null. Output will have same shape as ``indices``. 
.set_attr("FTVMCompute", GatherCompute) .set_attr("TOpPattern", kInjective); -TVM_REGISTER_NODE_TYPE(GatherNDAttrs); - // gather_nd operator bool GatherNDRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { @@ -3369,7 +3367,6 @@ When B == 0 (the default case), the output shape will be (Y_0, ..., Y_{K-1}, X_M In both cases, if M + B == N, the output shape will simply be (Y_0, ..., Y_{K-1}). )code" TVM_ADD_FILELINE) .set_num_inputs(2) - .set_attrs_type() .add_argument("data", "Tensor", "The input tensor.") .add_argument("indices", "Tensor", "The indices of values to gather.") .set_support_level(3) @@ -3986,8 +3983,7 @@ Expr MakeInvertPermutation(Expr data) { TVM_REGISTER_GLOBAL("relay.op._make.invert_permutation").set_body_typed(MakeInvertPermutation); RELAY_REGISTER_OP("invert_permutation") - .describe( - R"doc(Computes the inverse permutation of a tensor.)doc" TVM_ADD_FILELINE) + .describe(R"doc(Computes the inverse permutation of a tensor.)doc" TVM_ADD_FILELINE) .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .add_type_rel("Identity", IdentityRel) From 40f90265c5a2ddb87b94548ce8d372550a94689d Mon Sep 17 00:00:00 2001 From: cailun01 <1354213521@qq.com> Date: Sun, 20 Jun 2021 15:10:09 +0800 Subject: [PATCH 5/7] fix python linting error --- python/tvm/topi/cuda/transform.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/cuda/transform.py b/python/tvm/topi/cuda/transform.py index 28a37f1e5f4c..5d6ee1f7b9f2 100644 --- a/python/tvm/topi/cuda/transform.py +++ b/python/tvm/topi/cuda/transform.py @@ -126,8 +126,12 @@ def invert_permutation(data): [data.shape], [data], lambda ins, outs: _invert_permutation_ir(ins[0], outs[0]), - in_buffers=[data_buf,], - out_buffers=[out_buf,], + in_buffers=[ + data_buf, + ], + out_buffers=[ + out_buf, + ], name="invert_permutation", tag="invert_permutation_gpu", ) From add05bba15d76087ff15a1721deed67eb279e006 Mon Sep 17 00:00:00 2001 From: cailun01 <1354213521@qq.com> Date: Sun, 20 Jun 2021 16:58:56 +0800 Subject: [PATCH 6/7] fix lint error --- python/tvm/topi/cuda/transform.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/python/tvm/topi/cuda/transform.py b/python/tvm/topi/cuda/transform.py index 5d6ee1f7b9f2..16b1273def47 100644 --- a/python/tvm/topi/cuda/transform.py +++ b/python/tvm/topi/cuda/transform.py @@ -85,23 +85,23 @@ def _invert_permutation_ir(data, out): """ elem_num = data.shape[0] - ib = tvm.tir.ir_builder.create() - data = ib.buffer_ptr(data) - out = ib.buffer_ptr(out) + irb = tvm.tir.ir_builder.create() + data = irb.buffer_ptr(data) + out = irb.buffer_ptr(out) max_threads = int(Target.current(allow_none=False).max_num_threads) nthread_tx = max_threads nthread_bx = elem_num // max_threads + 1 - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * max_threads + tx + thread_x = te.thread_axis("threadIdx.x") + block_x = te.thread_axis("blockIdx.x") + irb.scope_attr(thread_x, "thread_extent", nthread_tx) + irb.scope_attr(block_x, "thread_extent", nthread_bx) + tid = block_x * max_threads + thread_x - with ib.if_scope(tid < elem_num): + with irb.if_scope(tid < elem_num): r_ind = data[tid] out[r_ind] = tid - return ib.get() + return irb.get() def invert_permutation(data): From 9829af47201bc3b8b5c503578dcd3963de4e3c87 Mon Sep 17 00:00:00 2001 From: cailun01 <1354213521@qq.com> Date: 
From 9829af47201bc3b8b5c503578dcd3963de4e3c87 Mon Sep 17 00:00:00 2001
From: cailun01 <1354213521@qq.com>
Date: Sun, 20 Jun 2021 21:45:09 +0800
Subject: [PATCH 7/7] restore mistakenly deleted code

---
 src/relay/op/tensor/transform.cc | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index b3a34e83b70a..5dc2a677f13f 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3290,6 +3290,8 @@ which must just be not null. Output will have same shape as ``indices``.
     .set_attr<FTVMCompute>("FTVMCompute", GatherCompute)
     .set_attr<TOpPattern>("TOpPattern", kInjective);
 
+TVM_REGISTER_NODE_TYPE(GatherNDAttrs);
+
 // gather_nd operator
 bool GatherNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                  const TypeReporter& reporter) {
@@ -3367,6 +3369,7 @@ When B == 0 (the default case), the output shape will be (Y_0, ..., Y_{K-1}, X_M
 In both cases, if M + B == N, the output shape will simply be (Y_0, ..., Y_{K-1}).
 )code" TVM_ADD_FILELINE)
     .set_num_inputs(2)
+    .set_attrs_type<GatherNDAttrs>()
     .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("indices", "Tensor", "The indices of values to gather.")
     .set_support_level(3)
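
Closing note: the scatter identity that both the hybrid-script CPU kernel and
the CUDA IR builder above implement (out[data[i]] = i) can be sanity-checked
with plain numpy, independent of TVM; the helper name below is illustrative
only, not part of the patch.

    import numpy as np

    def invert_permutation_ref(perm):
        # Same scatter the kernels perform: out[perm[i]] = i.
        out = np.empty_like(perm)
        out[perm] = np.arange(perm.shape[0], dtype=perm.dtype)
        return out

    perm = np.array([3, 4, 0, 2, 1, 5], dtype="int32")  # input from test_invert_permutation
    print(invert_permutation_ref(perm))  # -> [2 4 3 0 1 5]
    # A permutation composed with its inverse is the identity:
    assert (invert_permutation_ref(perm)[perm] == np.arange(6)).all()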