From 00b635210ed41716890bbb9815777cff148e55be Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 8 Oct 2021 05:43:00 +0900 Subject: [PATCH 1/5] support gather op dynamic input --- python/tvm/relay/op/_transform.py | 9 +++++++++ src/relay/op/tensor/transform.cc | 6 ++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 0284d2483ce5..9be58ab235f1 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -1174,3 +1174,12 @@ def gather_nd_shape_func(attrs, inputs, _): assert index_rank > 0, "index_rank needs to be specified for dynamic gather_nd" return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(index_rank))] + + +@_reg.register_shape_func("gather", False) +def gather_shape_func(attrs, inputs, _): + """ + Shape func for gather operator. + """ + indices_shape = inputs[1] + return [te.compute((1,), lambda i: indices_shape[i])] diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 3781107eeee1..fa5b31a8abef 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3260,8 +3260,10 @@ bool GatherRel(const Array& types, int num_inputs, const Attrs& attrs, oshape.reserve(ndim_data); for (size_t i = 0; i < ndim_data; ++i) { if (i == static_cast(axis)) { - const int64_t* indice_shape_i = tir::as_const_int(indices->shape[i]); - ICHECK_GE(*indice_shape_i, 1); + if (indices->shape[i].as()) { + const int64_t* indice_shape_i = tir::as_const_int(indices->shape[i]); + ICHECK_GE(*indice_shape_i, 1); + } } else { ICHECK(reporter->AssertEQ(indices->shape[i], data->shape[i])); } From a102accafc0b61e570d0b61c6b518acff66901db Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 8 Oct 2021 19:18:07 +0900 Subject: [PATCH 2/5] fix shape func and add test --- python/tvm/relay/op/_transform.py | 13 +++++++++++-- tests/python/relay/test_any.py | 25 ++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 9be58ab235f1..71d61c801f97 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -1176,10 +1176,19 @@ def gather_nd_shape_func(attrs, inputs, _): return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(index_rank))] +@script +def _gather_shape(data_shape, indices_shape, axis): + out_shape = output_tensor((1,), "int64") + for i in range(data_shape.shape[0]): + if i != axis: + assert data_shape[i] == indices_shape[i], "data and indices size at non-gather axes must be the same" + out_shape[i] = indices_shape[i] + return out_shape + + @_reg.register_shape_func("gather", False) def gather_shape_func(attrs, inputs, _): """ Shape func for gather operator. """ - indices_shape = inputs[1] - return [te.compute((1,), lambda i: indices_shape[i])] + return [_gather_shape(inputs[0], inputs[1], attrs.axis)] diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index decddc1ef0a4..bd5b1c6a3be2 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -2064,5 +2064,28 @@ def verify_scatter_nd(data_np, indices_np, updates_np, ref_res): verify_scatter_nd(data, indices, updates, out) +@tvm.testing.uses_gpu +def test_gather(): + def verify_gather(data_shape, indices_shape, data_shape_np, indices_shape_np, axis): + x = relay.var("x", relay.TensorType(data_shape, "float32")) + y = relay.var("y", relay.TensorType(indices_shape, "int32")) + z = relay.gather(x, axis, y) + + mod = tvm.IRModule() + mod["main"] = relay.Function([x, y], z) + + data_np = np.random.uniform(size=data_shape_np).astype("float32") + indices_np = np.random.randint(low=0, high=2, size=indices_shape_np, dtype="int32") + + ref_res = tvm.topi.testing.gather_python(data_np, axis, indices_np) + check_result([data_np, indices_np], mod, [ref_res]) + + verify_gather((relay.Any(),), (relay.Any(),), (10,), (10,), 0) + verify_gather((2, 2), (2, relay.Any()), (2, 2), (2, 3), 1) + verify_gather((relay.Any(), 2), (2, relay.Any()), (2, 2), (2, 3), 1) + verify_gather((relay.Any(), relay.Any()), (relay.Any(), relay.Any()), (2, 3), (1, 3), 0) + + if __name__ == "__main__": - pytest.main([__file__]) + # pytest.main([__file__]) + test_gather() From 20c4e66cdfc68a3ab9f2f96573b0b76dd9565fb0 Mon Sep 17 00:00:00 2001 From: masa Date: Fri, 8 Oct 2021 20:34:24 +0900 Subject: [PATCH 3/5] remove constness check --- include/tvm/topi/transform.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 8d1a49a4cc5f..37bac33485df 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1233,8 +1233,6 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, } ICHECK_GE(axis, 0); ICHECK_LT(axis, ndim_d); - size_t indices_dim_i = static_cast(GetConstInt(indices->shape[axis])); - ICHECK_GE(indices_dim_i, 1); ICHECK(indices->dtype.is_int()); Array out_shape; From 5667fb0b1ba09f23b8ded6e283fbde7fb7329cdf Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 11 Oct 2021 15:27:05 +0900 Subject: [PATCH 4/5] fix shape func output rank --- python/tvm/relay/op/_transform.py | 6 ++++-- tests/python/relay/test_any.py | 3 +-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 71d61c801f97..76c806905b18 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -1178,10 +1178,12 @@ def gather_nd_shape_func(attrs, inputs, _): @script def _gather_shape(data_shape, indices_shape, axis): - out_shape = output_tensor((1,), "int64") + out_shape = output_tensor((data_shape.shape[0],), "int64") for i in range(data_shape.shape[0]): if i != axis: - assert data_shape[i] == indices_shape[i], "data and indices size at non-gather axes must be the same" + assert ( + data_shape[i] == indices_shape[i] + ), "data and indices size at non-gather axes must be the same" out_shape[i] = indices_shape[i] return out_shape diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index bd5b1c6a3be2..8788faf45866 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -2087,5 +2087,4 @@ def verify_gather(data_shape, indices_shape, data_shape_np, indices_shape_np, ax if __name__ == "__main__": - # pytest.main([__file__]) - test_gather() + pytest.main([__file__]) From daa57838b674d1f669f7f94d2fe21d43ee476dad Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 11 Oct 2021 15:31:21 +0900 Subject: [PATCH 5/5] restore check --- include/tvm/topi/transform.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 37bac33485df..3df9caf55d5c 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1233,6 +1233,10 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, } ICHECK_GE(axis, 0); ICHECK_LT(axis, ndim_d); + if (indices->shape[axis].as()) { + size_t indices_dim_i = static_cast(GetConstInt(indices->shape[axis])); + ICHECK_GE(indices_dim_i, 1); + } ICHECK(indices->dtype.is_int()); Array out_shape;