diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h index ea41488354d8..e6c16d233a6b 100644 --- a/include/tvm/relax/attrs/manipulate.h +++ b/include/tvm/relax/attrs/manipulate.h @@ -152,6 +152,23 @@ struct FlipAttrs : public tvm::AttrsNode { } }; // struct FlipAttrs +/*! \brief Attributes used in gather_elements operators */ +struct GatherElementsAttrs : public tvm::AttrsNode { + Integer axis; + + TVM_DECLARE_ATTRS(GatherElementsAttrs, "relax.attrs.GatherElementsAttrs") { + TVM_ATTR_FIELD(axis).set_default(0).describe("The axis along which to index."); + } +}; // struct GatherElementsAttrs + +/*! \brief Attributes used in gather_nd operators */ +struct GatherNDAttrs : public tvm::AttrsNode { + Integer batch_dims; + TVM_DECLARE_ATTRS(GatherNDAttrs, "relax.attrs.GatherNDAttrs") { + TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dims."); + } +}; // struct GatherNDAttrs + /*! \brief Attributes used in scatter_elements operators */ struct ScatterElementsAttrs : public tvm::AttrsNode { Integer axis; diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index eb7a3eaf3628..dc2d9c6193ed 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -781,6 +781,24 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.op.take(data, indices, axis) +class GatherElements(OnnxOpConverter): + """Convert an onnx GatherElements node into an equivalent Relax expression.""" + + @classmethod + def _impl_v13(cls, bb, inputs, attr, params): + axis = attr.get("axis", 0) + return relax.op.gather_elements(inputs[0], inputs[1], axis) + + +class GatherND(OnnxOpConverter): + """Convert an onnx GatherND node into an equivalent Relax expression.""" + + @classmethod + def _impl_v13(cls, bb, inputs, attr, params): + batch_dims = attr.get("batch_dims", 0) + return relax.op.gather_nd(inputs[0], inputs[1], batch_dims) + + class Scatter(OnnxOpConverter): """Convert an onnx Scatter node into an equivalent Relax expression.""" @@ -3070,8 +3088,8 @@ def _get_convert_map(): "Squeeze": Squeeze, "Constant": Constant, "Gather": Gather, - # "GatherElements": GatherElements, - # "GatherND": GatherND, + "GatherElements": GatherElements, + "GatherND": GatherND, "Scatter": Scatter, "ScatterElements": ScatterElements, "ScatterND": ScatterND, diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 1603ea2f0f7e..97f18a239640 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -92,6 +92,8 @@ expand_dims, flatten, flip, + gather_elements, + gather_nd, layout_transform, one_hot, permute_dims, diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 3210cc821689..0f6e537ab3d6 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -435,6 +435,79 @@ def flip(data, axis): return _ffi_api.flip(data, axis) # type: ignore +def gather_elements(data: Expr, indices: Expr, axis: int = 0) -> Expr: + """Gather elements from data according to indices along the specified axis. + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + indices : relax.Expr + The indices tensor, must have integer type. + + axis : int + The axis along which to index. Default is 0. + + Returns + ------- + ret : relax.Expr + The computed result. + + Examples + -------- + .. code-block:: python + + data = [[1, 2], [3, 4]] + indices = [[0, 0], [1, 0]] + axis = 1 + output = [[1, 1], [4, 3]] + + data = [[1, 2, 3], [4, 5, 6]] + indices = [[1, 1, 1]] + axis = 0 + output = [[4, 5, 6]] + """ + return _ffi_api.gather_elements(data, indices, axis) # type: ignore + + +def gather_nd(data: Expr, indices: Expr, batch_dims: int = 0) -> Expr: + """Update data at positions defined by indices with values in updates. + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + indices : relax.Expr + The indices tensor, must have integer type. + + batch_dims : int + The number of batch dimensions. Default is 0. + + Returns + ------- + ret : relax.Expr + The computed result. + + Examples + -------- + .. code-block:: python + + batch_dims = 0 + data = [[0,1],[2,3]] # data_shape = [2, 2] + indices = [[0,0],[1,1]] # indices_shape = [2, 2] + output = [0,3] # output_shape = [2] + + batch_dims = 1 + data = [[[0,1],[2,3]],[[4,5],[6,7]]] # data_shape = [2, 2, 2] + indices = [[1],[0]] # indices_shape = [2, 1] + output = [[2,3],[4,5]] # output_shape = [2, 2] + + """ + return _ffi_api.gather_nd(data, indices, batch_dims) # type: ignore + + def scatter_elements( data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str = "update" ): diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 163085a07c34..55bc2772bcce 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -156,6 +156,22 @@ def _flip(bb: BlockBuilder, call: Call) -> Expr: return bb.call_te(topi.flip, call.args[0], int(call.attrs.axis)) +@register_legalize("relax.gather_elements") +def _gather_elements(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(topi.gather, call.args[0], int(call.attrs.axis), call.args[1]) + + +@register_legalize("relax.gather_nd") +def _gather_nd(bb: BlockBuilder, call: Call) -> Expr: + def te_gather_nd(data, indices, batch_dims): + indices_ndim = len(indices.shape) + axes = [indices_ndim - 1] + list(range(indices_ndim - 1)) + indices = topi.transpose(indices, axes) + return topi.gather_nd(data, indices, batch_dims) + + return bb.call_te(te_gather_nd, call.args[0], call.args[1], int(call.attrs.batch_dims)) + + @register_legalize("relax.scatter_elements") def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr: return bb.call_te( diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 049345fcb10d..ddc534cf6086 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -94,6 +94,8 @@ floor_mod, full, full_like, + gather_elements, + gather_nd, grad, greater, greater_equal, @@ -772,6 +774,8 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "func_ret_struct_info", "func_ret_value", "function", + "gather_elements", + "gather_nd", "gpu", "grad", "greater", diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 3b007a632599..31ce09af60ad 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -525,7 +525,7 @@ def gather(data, axis, indices): return cpp.gather(data, axis, indices) -def gather_nd(a, indices): +def gather_nd(a, indices, batch_dims=0): """Gather elements from a n-dimension array.. Parameters @@ -540,7 +540,7 @@ def gather_nd(a, indices): ------- ret : tvm.te.Tensor """ - return cpp.gather_nd(a, indices) + return cpp.gather_nd(a, indices, batch_dims) def matmul(a, b, transp_a=False, transp_b=False): diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index ba443413025a..f64b3ec4f979 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -1418,6 +1418,169 @@ TVM_REGISTER_OP("relax.flip") .set_attr("FInferStructInfo", InferStructInfoFlip) .set_attr("FPurity", Bool(true)); +/* relax.gather_elements */ +TVM_REGISTER_NODE_TYPE(GatherElementsAttrs); + +Expr gather_elements(Expr data, Expr indices, int axis) { + auto attrs = make_object(); + attrs->axis = Integer(axis); + static const Op& op = Op::Get("relax.gather_elements"); + return Call(op, {data, indices}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.gather_elements").set_body_typed(gather_elements); + +StructInfo InferStructInfoGatherElements(const Call& call, const BlockBuilder& ctx) { + const auto* data_sinfo = GetStructInfoAs(call->args[0]); + const auto* indices_sinfo = GetStructInfoAs(call->args[1]); + const auto* attrs = call->attrs.as(); + + if (data_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "GatherElements requires the input data to be a Tensor. However, the given one is " + << call->args[0]->struct_info_->GetTypeKey()); + } + if (indices_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "GatherElements requires the input indices to be a Tensor. However, the given one is " + << call->args[1]->struct_info_->GetTypeKey()); + } + + if (!indices_sinfo->IsUnknownDtype() && !indices_sinfo->dtype.is_int()) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "GatherElements requires the input indices to have int64 dtype. However, the " + << "given indices dtype is " << indices_sinfo->dtype); + } + + if (data_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim()) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + } + + int axis = attrs->axis.IntValue(); + if (axis < -data_sinfo->ndim || axis >= data_sinfo->ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "GatherElements requires axis to be within the input dimension range [" + << -data_sinfo->ndim << ", " << data_sinfo->ndim - 1 << "]. However, the " + << "given axis is " << axis); + } + + if (data_sinfo->ndim != indices_sinfo->ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "GatherElements requires data and indices to have the same rank. However, " + << "data rank is " << data_sinfo->ndim << " while indices rank is " + << indices_sinfo->ndim); + } + if (indices_sinfo->shape.defined()) { + return TensorStructInfo(indices_sinfo->shape.value(), data_sinfo->dtype, data_sinfo->vdevice); + } + return TensorStructInfo(data_sinfo->dtype, indices_sinfo->ndim, data_sinfo->vdevice); +} + +TVM_REGISTER_OP("relax.gather_elements") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("indices", "Tensor", "The indices tensor.") + .set_attr("FInferStructInfo", InferStructInfoGatherElements) + .set_attr("FPurity", Bool(true)); + +/* relax.gather_nd */ +TVM_REGISTER_NODE_TYPE(GatherNDAttrs); + +Expr gather_nd(Expr data, Expr indices, int batch_dims) { + auto attrs = make_object(); + attrs->batch_dims = Integer(batch_dims); + static const Op& op = Op::Get("relax.gather_nd"); + return Call(op, {data, indices}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.gather_nd").set_body_typed(gather_nd); + +StructInfo InferStructInfoGatherND(const Call& call, const BlockBuilder& ctx) { + const auto* data_sinfo = GetStructInfoAs(call->args[0]); + const auto* indices_sinfo = GetStructInfoAs(call->args[1]); + const auto* attrs = call->attrs.as(); + + if (data_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "GatherND requires the input data to be a Tensor. However, the given one is " + << call->args[0]->struct_info_->GetTypeKey()); + } + if (indices_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "GatherND requires the input indices to be a Tensor. However, the given one is " + << call->args[1]->struct_info_->GetTypeKey()); + } + ICHECK_GE(attrs->batch_dims.IntValue(), 0); + int batch_dims = attrs->batch_dims.IntValue(); + int input_dims = data_sinfo->ndim; + if (!indices_sinfo->IsUnknownDtype() && indices_sinfo->dtype != DataType::Int(64)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "GatherND requires the input indices to have int64 dtype. However, the " + << "given indices dtype is " << indices_sinfo->dtype); + } + + if (data_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim()) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + } + + if (batch_dims < 0 || batch_dims > data_sinfo->ndim) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "GatherND batch_dims must be in range [0, data.ndim]. However, got batch_dims=" + << batch_dims << ", data.ndim=" << input_dims); + } + + if (batch_dims > indices_sinfo->ndim - 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "GatherND batch_dims cannot exceed indices.ndim-1. However, got batch_dims=" + << batch_dims << ", indices.ndim=" << indices_sinfo->ndim); + } + + // Check if indices shape is known + const auto* indices_shape = indices_sinfo->shape.as(); + const auto* data_shape = data_sinfo->shape.as(); + if (!indices_shape || !indices_shape->values.back()->IsInstance()) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + } + int l = indices_shape->values.back().as()->value; + int output_ndim = indices_sinfo->ndim + input_dims - l - 1 - batch_dims; + if (!data_shape) { + return TensorStructInfo(data_sinfo->dtype, output_ndim, data_sinfo->vdevice); + } + + // In this condition, all input shapes are known + Array out_shape; + if (l > input_dims - batch_dims) { + ctx->ReportFatal(Diagnostic::Error(call) + << "GatherND requires the last dimension of indices to be less than or " + "equal to the rank of data minus batch_dims. However, the given shapes are " + << "indices: " << ShapeExpr(indices_shape->values) << ", data: " + << ShapeExpr(data_shape->values) << ", with batch_dims=" << batch_dims); + } + for (int i = 0; i < indices_sinfo->ndim - 1; ++i) { + out_shape.push_back(indices_shape->values[i]); + } + for (int i = batch_dims + l; i < input_dims; ++i) { + out_shape.push_back(data_shape->values[i]); + } + ICHECK_EQ(out_shape.size(), output_ndim); + return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); +} + +TVM_REGISTER_OP("relax.gather_nd") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("indices", "Tensor", "The indices tensor.") + .set_attr("FInferStructInfo", InferStructInfoGatherND) + .set_attr("FPurity", Bool(true)); + /* relax.scatter_elements */ TVM_REGISTER_NODE_TYPE(ScatterElementsAttrs); diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 010ceb663ef3..1a0c7ddbc76c 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -174,6 +174,32 @@ Expr tile(Expr data, Array repeats); */ Expr flip(Expr data, Integer axis); +/*! + * \brief Gather elements from a tensor using indices. + * \param data The input tensor. + * \param indices The indices tensor, must have integer type. + * \param axis The axis along which to index. Default is 0. + * \return The computed result. + * + * \note The shape of indices must match the shape of data, except at dimension axis + * where it must just be not null. The output will have the same shape as indices. + */ +Expr gather_elements(Expr data, Expr indices, int axis = 0); + +/*! + * \brief Gather values from a tensor using N-dimensional indices. + * \param data The input tensor. + * \param indices The indices tensor, must have integer type. + * \param batch_dims The number of batch dimensions. Default is 0. + * \return The computed result. + * + * \note For batch_dims > 0, the first batch_dims dimensions of data and indices must be equal. + * The last dimension of indices indicates the depth of each index vector. + * The output shape is batch_dims + indices.shape[:-1] + data.shape[batch_dims + + * indices.shape[-1]:] + */ +Expr gather_nd(Expr data, Expr indices, int batch_dims = 0); + /*! * \brief Scatter updates into an array according to indices. * \param data The input tensor. diff --git a/src/topi/transform.cc b/src/topi/transform.cc index a84e3dce500c..72291fd96123 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -135,7 +135,8 @@ TVM_REGISTER_GLOBAL("topi.gather").set_body([](TVMArgs args, TVMRetValue* rv) { }); TVM_REGISTER_GLOBAL("topi.gather_nd").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = gather_nd(args[0], args[1]); + int batch_dims = args[2]; + *rv = gather_nd(args[0], args[1], batch_dims); }); TVM_REGISTER_GLOBAL("topi.unravel_index").set_body([](TVMArgs args, TVMRetValue* rv) { diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 6f74957a0781..b1f9092b9f8b 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -544,6 +544,68 @@ def _verify_gather(data_shape, indices, out_shape, axis=0): _verify_gather([3, 3], [[0, 2]], [3, 1, 2], 1) +@pytest.mark.parametrize( + "data_shape, indices_shape, axis", + [ + ([3, 4, 5], [1, 4, 5], 0), + ([3, 4, 5], [3, 2, 5], 1), + ([3, 4, 5], [3, 4, 2], 2), + ], +) +def test_gather_elements(data_shape, indices_shape, axis): + gather_elements_node = helper.make_node("GatherElements", ["data", "indices"], ["y"], axis=axis) + + graph = helper.make_graph( + [gather_elements_node], + "gather_elements_test", + inputs=[ + helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape), + helper.make_tensor_value_info("indices", TensorProto.INT64, indices_shape), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, indices_shape)], + ) + + model = helper.make_model(graph, producer_name="gather_elements_test") + input_values = { + "data": np.random.randn(*data_shape).astype("float32"), + "indices": np.random.randint(0, data_shape[axis], indices_shape).astype("int64"), + } + check_correctness(model, inputs=input_values) + + +@pytest.mark.parametrize( + "data_shape, indices_shape, batch_dims", + [ + ([2, 2], [2, 2], 0), + ([2, 2], [2, 1], 0), + ([2, 2, 2], [1], 0), + ([2, 2, 2], [2, 2], 0), + ([2, 2, 2], [2, 1, 2], 0), + ([2, 2, 2], [2, 2], 1), + ([2, 2, 2], [2, 1], 1), + ], +) +def test_gather_nd(data_shape, indices_shape, batch_dims): + gather_nd_node = helper.make_node("GatherND", ["data", "indices"], ["y"], batch_dims=batch_dims) + + graph = helper.make_graph( + [gather_nd_node], + "gather_nd_test", + inputs=[ + helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape), + helper.make_tensor_value_info("indices", TensorProto.INT64, indices_shape), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, None)], + ) + + model = helper.make_model(graph, producer_name="gather_nd_test") + input_values = { + "data": np.random.randn(*data_shape).astype("float32"), + "indices": np.random.randint(0, 2, indices_shape).astype("int64"), + } + check_correctness(model, inputs=input_values) + + @pytest.mark.parametrize("axis", [0, 1, 2]) @pytest.mark.parametrize(("name", "opset"), [("Scatter", 10), ("ScatterElements", 11)]) def test_scatter(axis: int, name: str, opset: int): diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py index f6aefc859114..23ab6780cf7b 100644 --- a/tests/python/relax/test_op_manipulate.py +++ b/tests/python/relax/test_op_manipulate.py @@ -3205,6 +3205,133 @@ def test_flip_infer_struct_info_wrong_inputs(): bb.normalize(relax.op.flip(x0, axis=3)) +def test_gather_elements_infer_struct_info(): + bb = relax.BlockBuilder() + vdev0 = VDevice("llvm") + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0)) + i0 = relax.Var("i", R.Tensor((2, 3, 4), "int64")) + i1 = relax.Var("i", R.Tensor((2, 3, 4))) + i2 = relax.Var("i", R.Tensor("int64", ndim=3)) + i3 = relax.Var("i", R.Tensor(ndim=3)) + i4 = relax.Var("i", R.Tensor((2, 3, 4), "int64", vdev0)) + + _check_inference( + bb, relax.op.gather_elements(x0, i0, axis=1), relax.TensorStructInfo((2, 3, 4), "float32") + ) + _check_inference( + bb, + relax.op.gather_elements(x3, i4, axis=1), + relax.TensorStructInfo((2, 3, 4), "float32", vdev0), + ) + _check_inference( + bb, + relax.op.gather_elements(x1, i0, axis=1), + relax.TensorStructInfo((2, 3, 4), dtype="float32"), + ) + _check_inference( + bb, + relax.op.gather_elements(x2, i0, axis=0), + relax.TensorStructInfo(dtype="float32", ndim=-1), + ) + _check_inference( + bb, relax.op.gather_elements(x0, i1, axis=1), relax.TensorStructInfo((2, 3, 4), "float32") + ) + _check_inference( + bb, + relax.op.gather_elements(x1, i2, axis=1), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.gather_elements(x2, i3, axis=0), + relax.TensorStructInfo(dtype="float32", ndim=-1), + ) + + +def test_gather_elements_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x = relax.Var("x", R.Tensor((a, b), "float32")) + i = relax.Var("i", R.Tensor((a, b), "int64")) + + _check_inference( + bb, relax.op.gather_elements(x, i, axis=1), relax.TensorStructInfo((a, b), "float32") + ) + + +def test_gather_elements_infer_struct_info_wrong_inputs(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3), "float32")) + i0 = relax.Var("i", R.Tensor((2, 3, 4), "int64")) + i1 = relax.Var("i", R.Tensor((2, 3), "int64")) + i2 = relax.Var("i", R.Tensor((2, 3, 4), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.gather_elements(x0, i0, axis=3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.gather_elements(x0, i1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.gather_elements(x1, i0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.gather_elements(x0, i2)) + + +def test_gather_nd_infer_struct_info(): + bb = relax.BlockBuilder() + vdev0 = VDevice("llvm") + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0)) + i0 = relax.Var("i", R.Tensor((2, 2), "int64")) + i1 = relax.Var("i", R.Tensor((2, 2))) + i2 = relax.Var("i", R.Tensor("int64", ndim=2)) + i3 = relax.Var("i", R.Tensor(ndim=2)) + i4 = relax.Var("i", R.Tensor((2, 2), "int64", vdev0)) + + _check_inference(bb, relax.op.gather_nd(x0, i0), relax.TensorStructInfo((2, 4), "float32")) + _check_inference( + bb, relax.op.gather_nd(x3, i4), relax.TensorStructInfo((2, 4), "float32", vdev0) + ) + _check_inference( + bb, relax.op.gather_nd(x1, i0), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.gather_nd(x2, i0), relax.TensorStructInfo(dtype="float32", ndim=-1) + ) + _check_inference(bb, relax.op.gather_nd(x0, i1), relax.TensorStructInfo((2, 4), "float32")) + _check_inference(bb, relax.op.gather_nd(x1, i2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.gather_nd(x2, i3), relax.TensorStructInfo(dtype="float32")) + + +def test_gather_nd_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + x = relax.Var("x", R.Tensor((a, b, c), "float32")) + i = relax.Var("i", R.Tensor((2, 2), "int64")) + + _check_inference(bb, relax.op.gather_nd(x, i), relax.TensorStructInfo((2, c), "float32")) + + +def test_gather_nd_infer_struct_info_wrong_inputs(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + i0 = relax.Var("i", R.Tensor((2, 4), "int64")) # indices too long + i1 = relax.Var("i", R.Tensor((2, 2), "float32")) # wrong dtype + + with pytest.raises(TVMError): + bb.normalize(relax.op.gather_nd(x0, i0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.gather_nd(x0, i1)) + + def test_scatter_elements_infer_struct_info(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm")