diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 84318c384418..72ab391561dd 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1119,9 +1119,11 @@ def _impl(inputs, attr, params, mod): def _gather_nd(): """GatherNd""" def _impl(inputs, attr, params, mod): + indices_dims = len(_infer_shape(inputs[1], mod)) + indices = _op.transpose(inputs[1], axes=[-1] + list(range(indices_dims-1))) return AttrCvt(op_name="gather_nd", ignores=['Tindices', 'Tparams',\ - 'Taxis', '_class'])(inputs, attr) + 'Taxis', '_class'])([inputs[0], indices], attr) return _impl def _stridedSlice(): diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 5c0391f6588d..8d9466cf1edc 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1387,11 +1387,11 @@ def test_forward_gather(): def test_forward_gather_nd(): """test operator GatherNd""" - np_data = np.random.uniform(1, 100, size=(2, 2)).astype(np.float32) + np_data = np.random.uniform(1, 100, size=(2, 2, 2)).astype(np.float32) tf.reset_default_graph() with tf.Graph().as_default(): - in_data = tf.placeholder(tf.float32, (2, 2), name="in_data") - tf.gather_nd(in_data, indices=[[1, 0], [0, 1]], name="gather_nd") + in_data = tf.placeholder(tf.float32, (2, 2, 2), name="in_data") + tf.gather_nd(in_data, indices=[[1, 0, 0], [0, 0, 0]], name="gather_nd") compare_tf_with_tvm([np_data], ['in_data:0'], 'gather_nd:0')