diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h index 67f99d9b417e..943d2f4d0d71 100644 --- a/include/tvm/relax/attrs/manipulate.h +++ b/include/tvm/relax/attrs/manipulate.h @@ -182,6 +182,20 @@ struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> { } }; // struct GatherNDAttrs +/*! \brief Attributes used in index_put operator */ +struct IndexPutAttrs : public tvm::AttrsNode<IndexPutAttrs> { + bool accumulate; + + TVM_DECLARE_ATTRS(IndexPutAttrs, "relax.attrs.IndexPutAttrs") { + TVM_ATTR_FIELD(accumulate) + .set_default(false) + .describe( + "Whether to accumulate (add) values rather than replace. " + "If true, performs tensor[indices] += values, " + "otherwise performs tensor[indices] = values."); + } +}; // struct IndexPutAttrs + /*! \brief Attributes used in scatter_elements operators */ struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> { Integer axis; diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 33f6ffc3132e..5dd78be4837d 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1118,6 +1118,23 @@ def _gather(self, node: fx.Node) -> relax.Var: index = self.env[node.args[2]] return self.block_builder.emit(relax.op.gather_elements(x, index, axis=dim)) + def _index_put(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + tensor = args[0] + indices = args[1] if len(args) > 1 else node.kwargs.get("indices") + values = args[2] if len(args) > 2 else node.kwargs.get("values") + accumulate = args[3] if len(args) > 3 else node.kwargs.get("accumulate", False) + + if indices is None or values is None: + raise ValueError("'indices' and 'values' arguments are required for index_put operation") + + if not isinstance(accumulate, bool): + raise TypeError("'accumulate' must be a boolean value, got {}".format(type(accumulate))) + + if isinstance(indices, (list, tuple)): + indices = relax.Tuple(indices) + return self.block_builder.emit(relax.op.index_put(tensor, indices, values, accumulate)) + def _index_tensor(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) indices = args[1] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 0434712050ed..9f097ef8df11 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -433,6 +433,7 @@ def create_convert_map( "flip.default": self._flip, "gather.default": self._gather, "index.Tensor": self._index_tensor, + "index_put_.default": self._index_put, "narrow.default": self._narrow, "permute.default": self._permute, "repeat.default": self._repeat, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 55abf20fcc03..6cb9db9b9f3d 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -801,6 +801,7 @@ def create_convert_map( "flatten": self._flatten, "flip": self._flip, "gather": self._gather, + "index_put_": self._index_put, "narrow": self._narrow, "numel": self._numel, "permute": self._permute, diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 097313a33da2..7b8c34b6415f 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -95,6 +95,7 @@ flip, gather_elements,
gather_nd, + index_put, index_tensor, layout_transform, one_hot, diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index a693adf4325c..13334d1479d9 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -595,6 +595,57 @@ def index_tensor(data: Expr, indices: Union[Expr, List[Expr]]) -> Expr: return _ffi_api.index_tensor(data, indices) # type: ignore +def index_put( + data: Expr, + indices: Union[Expr, Tuple[Expr]], + values: Expr, + accumulate: bool = False, +) -> Expr: + """Put values from `values` into `data` at the positions specified by `indices`. + The `indices` is a tuple of 1D index tensors, one for each dimension of `data`. + When `accumulate` is True, the values are added to the existing entries of `data` + rather than replacing them, i.e. the operation performs data[indices] += values + instead of data[indices] = values. + Parameters + ---------- + data : relax.Expr + The input tensor to be modified + indices : Union[Expr, Tuple[Expr]] + Tuple of index tensors (one for each dimension) specifying positions to update + values : relax.Expr + Values to place at the specified indices + accumulate : bool + Whether to accumulate (add) values rather than replace (default: False) + + Returns + ------- + result : relax.Expr + A new tensor with the same shape as data but with specified positions updated + Examples + -------- + .. code-block:: python + # inputs + data = torch.ones(3, 3) + indices = (torch.tensor([0, 2]), torch.tensor([1, 1])) + values = torch.tensor([1.0, 2.0]) + # output (accumulate=False) + output = [ + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + [1.0, 2.0, 1.0], + ] + # with accumulate=True + output = [ + [1.0, 2.0, 1.0], + [1.0, 1.0, 1.0], + [1.0, 3.0, 1.0], + ] + """ + if isinstance(indices, (list, tuple)): + indices = RxTuple(indices) + return _ffi_api.index_put(data, indices, values, accumulate) # type: ignore + + def scatter_elements( data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str = "update" ): diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index fda4258a093b..fe527e38e8a8 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -144,6 +144,11 @@ class StackAttrs(Attrs): """Attributes for concat operator""" +@tvm._ffi.register_object("relax.attrs.IndexPutAttrs") +class IndexPutAttrs(Attrs): + """Attributes for index_put operator""" + + @tvm._ffi.register_object("relax.attrs.LayoutTransformAttrs") class LayoutTransformAttrs(Attrs): """Attributes used in layout_transform operator""" diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 84baa887d9aa..a66b60c0134b 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -193,6 +193,28 @@ def _index_tensor(bb: BlockBuilder, call: Call) -> Expr: return bb.call_te(topi.index_tensor, call.args[0], fields) +@register_legalize("relax.index_put") +def _index_put(bb: BlockBuilder, call: Call) -> Expr: + data = call.args[0] + indices = call.args[1] + values = call.args[2] + accumulate = call.attrs.accumulate + + # If indices is a Tuple, unpack it into individual tensors + if isinstance(indices, relax.Tuple): + indices_list = [indices.fields[i] for i in range(len(indices.fields))] + else: + indices_list = [indices] + + return bb.call_te( + topi.index_put, + data, + indices_list, + values, + accumulate=accumulate, + ) + +
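Editorial aside, not part of the patch: a minimal usage sketch, assuming only the public `tvm.relax` API, of how the Python op wrapper and the legalization above fit together. It builds a small Relax function with the new `relax.op.index_put` through a `BlockBuilder` and lowers it to the `topi.index_put` kernel with `LegalizeOps`; all shapes and variable names are arbitrary examples.

```python
import tvm
from tvm import relax

bb = relax.BlockBuilder()
data = relax.Var("data", relax.TensorStructInfo((3, 3), "float32"))
idx0 = relax.Var("idx0", relax.TensorStructInfo((2,), "int64"))
idx1 = relax.Var("idx1", relax.TensorStructInfo((2,), "int64"))
values = relax.Var("values", relax.TensorStructInfo((2,), "float32"))

with bb.function("main", [data, idx0, idx1, values]):
    with bb.dataflow():
        # One 1D index tensor per dimension of `data`, packed into a relax Tuple.
        put = bb.emit(
            relax.op.index_put(data, relax.Tuple([idx0, idx1]), values, accumulate=True)
        )
        gv = bb.emit_output(put)
    bb.emit_func_output(gv)

mod = bb.get()
# LegalizeOps rewrites relax.index_put into a call_tir to the generated TOPI kernel.
mod = relax.transform.LegalizeOps()(mod)
mod.show()
```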
@register_legalize("relax.scatter_elements") def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr: return bb.call_te( diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 22b00cd70416..d2952ed8e0d7 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -100,6 +100,7 @@ greater, greater_equal, hint_on_device, image, + index_put, index_tensor, invoke_closure, @@ -785,6 +786,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "greater_equal", "hexagon", "hint_on_device", "image", + "index_put", "index_tensor", "invoke_closure", diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py index 1de6941c9923..fa4e98a89a42 100644 --- a/python/tvm/topi/__init__.py +++ b/python/tvm/topi/__init__.py @@ -33,6 +33,7 @@ from .math import * from .tensor import * from .generic_op_impl import * +from .index_put import * from .reduction import * from .transform import * from .broadcast import * diff --git a/python/tvm/topi/index_put.py b/python/tvm/topi/index_put.py new file mode 100644 index 000000000000..f51c6718ab99 --- /dev/null +++ b/python/tvm/topi/index_put.py @@ -0,0 +1,117 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""IndexPut operator""" +from tvm import te +from tvm import tir +from . import utils + + +def index_put(data, indices, values, accumulate=False): + """Put values into an array according to indices. + + Parameters + ---------- + data : tvm.te.Tensor + The source array to be modified. + + indices : Tuple[tvm.te.Tensor] + Tuple of 1D index tensors (one for each dimension) specifying positions. + + values : tvm.te.Tensor + The values to place at the specified indices. + + accumulate : bool, optional + Whether to accumulate (add) values rather than replace. + If True, performs tensor[indices] += values + If False, performs tensor[indices] = values + Default is False.
+ + Returns + ------- + ret : tvm.te.Tensor + The updated tensor, with the same shape and dtype as data. + """ + if not isinstance(indices, (list, tuple)): + indices = [indices] + + # Check indices match data dimensions + if len(indices) != len(data.shape): + raise ValueError( + f"Number of index tensors ({len(indices)}) must match " + f"data dimensions ({len(data.shape)})" + ) + + # Prepare ranges and strides + shape = data.shape + full_range = 1 + for dim in shape: + full_range *= dim + + # Check all indices have the same length + index_len = indices[0].shape[0] + for idx in indices[1:]: + if not utils.equal_const_int(idx.shape[0] - index_len, 0): + raise ValueError("All index tensors must have the same length") + + def gen_ir(data_ptr, index_ptrs, values_ptr, out_ptr, reduce_func): + ir_builder = tir.ir_builder.create() + + data = ir_builder.buffer_ptr(data_ptr) + indices = [ir_builder.buffer_ptr(idx) for idx in index_ptrs] + values = ir_builder.buffer_ptr(values_ptr) + out = ir_builder.buffer_ptr(out_ptr) + + with ir_builder.for_range(0, full_range, "i", kind="parallel") as i: + out[i] = data[i] + + with ir_builder.for_range(0, index_len, "k", kind="parallel") as k: + # Calculate multi-dimensional index + flat_index = 0 + stride = 1 + for dim in range(len(shape) - 1, -1, -1): + # Get index and shift to positive if needed + idx_val = indices[dim][k] + shifted_idx = idx_val + (idx_val < 0) * shape[dim] + flat_index += shifted_idx * stride + stride *= shape[dim] + + reduce_func(out, flat_index, values[k]) + + return ir_builder.get() + + def update_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] = update + + def add_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] += update + + reduce_func = add_func if accumulate else update_func + + # Prepare input buffers + in_buffers = [data] + in_buffers.extend(indices) + in_buffers.append(values) + + out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf") + return te.extern( + [data.shape], + in_buffers, + lambda ins, outs: gen_ir(ins[0], ins[1:-1], ins[-1], outs[0], reduce_func), + dtype=data.dtype, + out_buffers=[out_buf], + name="index_put.generic", + tag="index_put.generic", + ) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index f56135a35bc3..482ebe5cacb6 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -1972,6 +1972,129 @@ TVM_REGISTER_OP("relax.gather_nd") .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGatherND) .set_attr<Bool>("FPurity", Bool(true)); +/* relax.index_put */ +TVM_REGISTER_NODE_TYPE(IndexPutAttrs); + +Expr index_put(Expr data, Expr indices, Expr values, bool accumulate) { + auto attrs = make_object<IndexPutAttrs>(); + attrs->accumulate = std::move(accumulate); + static const Op& op = Op::Get("relax.index_put"); + return Call(op, {data, indices, values}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.index_put").set_body_typed(index_put); + +StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) { + const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]); + const auto* values_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]); + + auto diag_def = [&](const TensorStructInfoNode* sinfo, String name, String type_key) { + if (sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << "IndexPut requires the input " << name + << " to be a Tensor. 
However, the given one is " << type_key); + } + }; + + diag_def(data_sinfo, "data", call->args[0]->struct_info_->GetTypeKey()); + diag_def(values_sinfo, "values", call->args[2]->struct_info_->GetTypeKey()); + + // Handle indices: either a single tensor or a tuple of tensors + Array<TensorStructInfo> indices_tensors; + + if (const auto* tuple_sinfo = GetStructInfoAs<TupleStructInfoNode>(call->args[1])) { + // Indices is a tuple of tensors + for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { + const auto* tensor_sinfo = tuple_sinfo->fields[i].as<TensorStructInfoNode>(); + if (tensor_sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << "IndexPut requires each index in the indices tuple to be a Tensor. " + << "However, element " << i << " is " + << tuple_sinfo->fields[i]->GetTypeKey()); + } + indices_tensors.push_back(GetRef<TensorStructInfo>(tensor_sinfo)); + } + } else if (const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1])) { + // Indices is a single tensor + indices_tensors.push_back(GetRef<TensorStructInfo>(tensor_sinfo)); + } else { + ctx->ReportFatal(Diagnostic::Error(call) + << "IndexPut requires indices to be a Tensor or a tuple of Tensors. " + << "However, the given one is " << call->args[1]->struct_info_->GetTypeKey()); + } + + if (data_sinfo->IsUnknownNdim()) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + } + + // Validate each index tensor + for (size_t i = 0; i < indices_tensors.size(); ++i) { + const auto& tensor_sinfo = indices_tensors[i]; + if (!tensor_sinfo->IsUnknownNdim() && tensor_sinfo->ndim != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "IndexPut requires each index tensor to be 1D. " + << "However, index tensor " << i << " has ndim=" << tensor_sinfo->ndim); + } + if (tensor_sinfo->IsUnknownDtype()) { + LOG(WARNING) << "Data type of index tensor " << i + << " has not been specified. Assume it has an integer type."; + } else if (!(tensor_sinfo->dtype.is_int() || tensor_sinfo->dtype.is_uint())) { + ctx->ReportFatal(Diagnostic::Error(call) + << "IndexPut requires each index tensor to have integer dtype. " + << "However, index tensor " << i << " has dtype=" << tensor_sinfo->dtype); + } + } + + // Check that the number of index tensors matches data dimensions + if (!data_sinfo->IsUnknownNdim() && + indices_tensors.size() != static_cast<size_t>(data_sinfo->ndim)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "IndexPut requires the number of index tensors (" << indices_tensors.size() + << ") to match the data tensor dimensions (" << data_sinfo->ndim << ")"); + } + + // Check data and values dtype compatibility + if (data_sinfo->IsUnknownDtype() || values_sinfo->IsUnknownDtype()) { + auto diag_dtype = [&](const TensorStructInfoNode* sinfo, String name) { + if (sinfo->IsUnknownDtype()) { + LOG(WARNING) << "Data type of " << name + << " has not been specified. Skipping the dtype compatibility check."; + } + }; + diag_dtype(data_sinfo, "data"); + diag_dtype(values_sinfo, "values"); + } else if (data_sinfo->dtype != values_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << "IndexPut requires the input data to have the same type as values. " << "However, the given types are data: " << data_sinfo->dtype + << ", values: " << values_sinfo->dtype); + } + + // Check values shape compatibility + const auto* values_shape = values_sinfo->shape.as<ShapeExprNode>(); + if (values_shape) { + if (values_sinfo->ndim != 1) { + LOG(WARNING) << "IndexPut typically expects values to be 1D, but got ndim=" + << values_sinfo->ndim; + } + } + + const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>(); + if (data_shape) { + return TensorStructInfo(ShapeExpr(data_shape->values), data_sinfo->dtype, data_sinfo->vdevice); + } + return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice); +} + +TVM_REGISTER_OP("relax.index_put") + .set_attrs_type<IndexPutAttrs>() + .set_num_inputs(3) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("indices", "Tensor", "The indices tensor(s).") + .add_argument("values", "Tensor", "The values to put.") + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoIndexPut) + .set_attr<Bool>("FPurity", Bool(true)); + /* relax.scatter_elements */ TVM_REGISTER_NODE_TYPE(ScatterElementsAttrs); diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 4580f9191bb8..2e4c92c15052 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -218,6 +218,19 @@ Expr gather_nd(Expr data, Expr indices, int batch_dims = 0); */ Expr index_tensor(Expr data, Expr indices); + +/*! + * \brief Put values into an array according to indices. + * \param data The input tensor to be modified. + * \param indices The index positions where values should be placed. + * This should be a tuple of 1D tensors (one for each dimension). + * \param values The values to place at the specified indices. + * \param accumulate Whether to accumulate (add) values rather than replace. + * If true, equivalent to tensor[indices] += values. + * If false, equivalent to tensor[indices] = values. + * \return The computed result with values placed at specified indices. + */ +Expr index_put(Expr data, Expr indices, Expr values, bool accumulate = false); + /*! * \brief Scatter updates into an array according to indices. * \param data The input tensor.
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 108617991b1f..4566654449cc 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4169,6 +4169,181 @@ def main( verify_model(Gather3(), example_args, {}, Expected3) +def test_index_put(): + # Test case 1: 1D input + class IndexPut1D(Module): + def forward(self, data, indices_0, values): + indices_tuple = (indices_0,) + return data.index_put_(indices_tuple, values, accumulate=False) + + example_args_1d = ( + torch.randn(64, dtype=torch.float32), + torch.randint(0, 64, (128,), dtype=torch.int64), + torch.randn(128, dtype=torch.float32), + ) + + @I.ir_module + class Expected1D: + @R.function + def main( + data: R.Tensor((64,), dtype="float32"), + indices_0: R.Tensor((128,), dtype="int64"), + values: R.Tensor((128,), dtype="float32"), + ) -> R.Tuple(R.Tensor((64,), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((64,), dtype="float32") = R.index_put( + data, R.tuple(indices_0), values, accumulate=False + ) + gv: R.Tuple(R.Tensor((64,), dtype="float32")) = (lv,) + R.output(gv) + return gv + + # Test case 2: 2D input + class IndexPut2D(Module): + def forward(self, data, indices_0, indices_1, values): + indices_tuple = (indices_0, indices_1) + return data.index_put_(indices_tuple, values, accumulate=False) + + example_args_2d = ( + torch.randn(32, 64, dtype=torch.float32), + torch.randint(0, 32, (128,), dtype=torch.int64), + torch.randint(0, 64, (128,), dtype=torch.int64), + torch.randn(128, dtype=torch.float32), + ) + + @I.ir_module + class Expected2D: + @R.function + def main( + data: R.Tensor((32, 64), dtype="float32"), + indices_0: R.Tensor((128,), dtype="int64"), + indices_1: R.Tensor((128,), dtype="int64"), + values: R.Tensor((128,), dtype="float32"), + ) -> R.Tuple(R.Tensor((32, 64), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((32, 64), dtype="float32") = R.index_put( + data, R.tuple(indices_0, indices_1), values, accumulate=False + ) + gv: R.Tuple(R.Tensor((32, 64), dtype="float32")) = (lv,) + R.output(gv) + return gv + + # Test case 3: 3D input + class IndexPut3D(Module): + def forward(self, data, indices_0, indices_1, indices_2, values): + indices_tuple = (indices_0, indices_1, indices_2) + return data.index_put_(indices_tuple, values, accumulate=False) + + example_args_3d = ( + torch.randn(16, 32, 64, dtype=torch.float32), + torch.randint(0, 16, (128,), dtype=torch.int64), + torch.randint(0, 32, (128,), dtype=torch.int64), + torch.randint(0, 64, (128,), dtype=torch.int64), + torch.randn(128, dtype=torch.float32), + ) + + @I.ir_module + class Expected3D: + @R.function + def main( + data: R.Tensor((16, 32, 64), dtype="float32"), + indices_0: R.Tensor((128,), dtype="int64"), + indices_1: R.Tensor((128,), dtype="int64"), + indices_2: R.Tensor((128,), dtype="int64"), + values: R.Tensor((128,), dtype="float32"), + ) -> R.Tuple(R.Tensor((16, 32, 64), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((16, 32, 64), dtype="float32") = R.index_put( + data, R.tuple(indices_0, indices_1, indices_2), values, accumulate=False + ) + gv: R.Tuple(R.Tensor((16, 32, 64), dtype="float32")) = (lv,) + R.output(gv) + return gv + + # Test case 4: 4D input + class IndexPut4D(Module): + def forward(self, data, indices_0, indices_1, indices_2, indices_3, values): + indices_tuple = (indices_0, indices_1, indices_2, indices_3) + return data.index_put_(indices_tuple, values, 
accumulate=False) + + example_args_4d = ( + torch.randn(8, 16, 32, 64, dtype=torch.float32), + torch.randint(0, 8, (128,), dtype=torch.int64), + torch.randint(0, 16, (128,), dtype=torch.int64), + torch.randint(0, 32, (128,), dtype=torch.int64), + torch.randint(0, 64, (128,), dtype=torch.int64), + torch.randn(128, dtype=torch.float32), + ) + + @I.ir_module + class Expected4D: + @R.function + def main( + data: R.Tensor((8, 16, 32, 64), dtype="float32"), + indices_0: R.Tensor((128,), dtype="int64"), + indices_1: R.Tensor((128,), dtype="int64"), + indices_2: R.Tensor((128,), dtype="int64"), + indices_3: R.Tensor((128,), dtype="int64"), + values: R.Tensor((128,), dtype="float32"), + ) -> R.Tuple(R.Tensor((8, 16, 32, 64), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((8, 16, 32, 64), dtype="float32") = R.index_put( + data, + R.tuple(indices_0, indices_1, indices_2, indices_3), + values, + accumulate=False, + ) + gv: R.Tuple(R.Tensor((8, 16, 32, 64), dtype="float32")) = (lv,) + R.output(gv) + return gv + + # Test case 5: 5D input + class IndexPut5D(Module): + def forward(self, data, indices_0, indices_1, indices_2, indices_3, indices_4, values): + indices_tuple = (indices_0, indices_1, indices_2, indices_3, indices_4) + return data.index_put_(indices_tuple, values, accumulate=False) + + example_args_5d = ( + torch.randn(4, 8, 16, 32, 64, dtype=torch.float32), + torch.randint(0, 4, (128,), dtype=torch.int64), + torch.randint(0, 8, (128,), dtype=torch.int64), + torch.randint(0, 16, (128,), dtype=torch.int64), + torch.randint(0, 32, (128,), dtype=torch.int64), + torch.randint(0, 64, (128,), dtype=torch.int64), + torch.randn(128, dtype=torch.float32), + ) + + @I.ir_module + class Expected5D: + @R.function + def main( + data: R.Tensor((4, 8, 16, 32, 64), dtype="float32"), + indices_0: R.Tensor((128,), dtype="int64"), + indices_1: R.Tensor((128,), dtype="int64"), + indices_2: R.Tensor((128,), dtype="int64"), + indices_3: R.Tensor((128,), dtype="int64"), + indices_4: R.Tensor((128,), dtype="int64"), + values: R.Tensor((128,), dtype="float32"), + ) -> R.Tuple(R.Tensor((4, 8, 16, 32, 64), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((4, 8, 16, 32, 64), dtype="float32") = R.index_put( + data, + R.tuple(indices_0, indices_1, indices_2, indices_3, indices_4), + values, + accumulate=False, + ) + gv: R.Tuple(R.Tensor((4, 8, 16, 32, 64), dtype="float32")) = (lv,) + R.output(gv) + return gv + + # Run verification for each case + verify_model(IndexPut1D(), example_args_1d, {}, Expected1D) + verify_model(IndexPut2D(), example_args_2d, {}, Expected2D) + verify_model(IndexPut3D(), example_args_3d, {}, Expected3D) + verify_model(IndexPut4D(), example_args_4d, {}, Expected4D) + verify_model(IndexPut5D(), example_args_5d, {}, Expected5D) + + def test_flip(): class Flip0(Module): def forward(self, data): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index cb69398e0a00..a77c5d82b58e 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -4404,6 +4404,177 @@ def main( verify_model(Gather3(), [([2, 3], "float32"), ([2, 3], "int32")], {}, Expected3) +def test_index_put(): + # Test case 1: 1D input + class IndexPut1D(Module): + def forward(self, data, indices_0, values): + indices_tuple = (indices_0,) + return data.index_put_(indices_tuple, values, accumulate=False) + + input_info_1d = [((64,), "float32"), ((128,), "int64"), ((128,), "float32")] + + @I.ir_module + class Expected1D: + @R.function 
+ def main( + data: R.Tensor((64,), dtype="float32"), + indices_0: R.Tensor((128,), dtype="int64"), + values: R.Tensor((128,), dtype="float32"), + ) -> R.Tensor((64,), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((64,), dtype="float32") = R.index_put( + data, R.tuple(indices_0), values, accumulate=False + ) + gv: R.Tensor((64,), dtype="float32") = lv + R.output(gv) + return gv + + # Test case 2: 2D input + class IndexPut2D(Module): + def forward(self, data, indices_0, indices_1, values): + indices_tuple = (indices_0, indices_1) + return data.index_put_(indices_tuple, values, accumulate=False) + + input_info_2d = [ + ((32, 64), "float32"), + ((128,), "int64"), + ((128,), "int64"), + ((128,), "float32"), + ] + + @I.ir_module + class Expected2D: + @R.function + def main( + data: R.Tensor((32, 64), dtype="float32"), + indices_0: R.Tensor((128,), dtype="int64"), + indices_1: R.Tensor((128,), dtype="int64"), + values: R.Tensor((128,), dtype="float32"), + ) -> R.Tensor((32, 64), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((32, 64), dtype="float32") = R.index_put( + data, R.tuple(indices_0, indices_1), values, accumulate=False + ) + gv: R.Tensor((32, 64), dtype="float32") = lv + R.output(gv) + return gv + + # Test case 3: 3D input + class IndexPut3D(Module): + def forward(self, data, indices_0, indices_1, indices_2, values): + indices_tuple = (indices_0, indices_1, indices_2) + return data.index_put_(indices_tuple, values, accumulate=False) + + input_info_3d = [ + ((16, 32, 64), "float32"), + ((128,), "int64"), + ((128,), "int64"), + ((128,), "int64"), + ((128,), "float32"), + ] + + @I.ir_module + class Expected3D: + @R.function + def main( + data: R.Tensor((16, 32, 64), dtype="float32"), + indices_0: R.Tensor((128,), dtype="int64"), + indices_1: R.Tensor((128,), dtype="int64"), + indices_2: R.Tensor((128,), dtype="int64"), + values: R.Tensor((128,), dtype="float32"), + ) -> R.Tensor((16, 32, 64), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((16, 32, 64), dtype="float32") = R.index_put( + data, R.tuple(indices_0, indices_1, indices_2), values, accumulate=False + ) + gv: R.Tensor((16, 32, 64), dtype="float32") = lv + R.output(gv) + return gv + + # Test case 4: 4D input + class IndexPut4D(Module): + def forward(self, data, indices_0, indices_1, indices_2, indices_3, values): + indices_tuple = (indices_0, indices_1, indices_2, indices_3) + return data.index_put_(indices_tuple, values, accumulate=False) + + input_info_4d = [ + ((8, 16, 32, 64), "float32"), + ((128,), "int64"), + ((128,), "int64"), + ((128,), "int64"), + ((128,), "int64"), + ((128,), "float32"), + ] + + @I.ir_module + class Expected4D: + @R.function + def main( + data: R.Tensor((8, 16, 32, 64), dtype="float32"), + indices_0: R.Tensor((128,), dtype="int64"), + indices_1: R.Tensor((128,), dtype="int64"), + indices_2: R.Tensor((128,), dtype="int64"), + indices_3: R.Tensor((128,), dtype="int64"), + values: R.Tensor((128,), dtype="float32"), + ) -> R.Tensor((8, 16, 32, 64), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((8, 16, 32, 64), dtype="float32") = R.index_put( + data, + R.tuple(indices_0, indices_1, indices_2, indices_3), + values, + accumulate=False, + ) + gv: R.Tensor((8, 16, 32, 64), dtype="float32") = lv + R.output(gv) + return gv + + # Test case 5: 5D input + class IndexPut5D(Module): + def forward(self, data, indices_0, indices_1, indices_2, indices_3, indices_4, values): + indices_tuple = (indices_0, indices_1, indices_2, indices_3, indices_4) + return data.index_put_(indices_tuple, values, 
accumulate=False) + + input_info_5d = [ + ((4, 8, 16, 32, 64), "float32"), + ((128,), "int64"), + ((128,), "int64"), + ((128,), "int64"), + ((128,), "int64"), + ((128,), "int64"), + ((128,), "float32"), + ] + + @I.ir_module + class Expected5D: + @R.function + def main( + data: R.Tensor((4, 8, 16, 32, 64), dtype="float32"), + indices_0: R.Tensor((128,), dtype="int64"), + indices_1: R.Tensor((128,), dtype="int64"), + indices_2: R.Tensor((128,), dtype="int64"), + indices_3: R.Tensor((128,), dtype="int64"), + indices_4: R.Tensor((128,), dtype="int64"), + values: R.Tensor((128,), dtype="float32"), + ) -> R.Tensor((4, 8, 16, 32, 64), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((4, 8, 16, 32, 64), dtype="float32") = R.index_put( + data, + R.tuple(indices_0, indices_1, indices_2, indices_3, indices_4), + values, + accumulate=False, + ) + gv: R.Tensor((4, 8, 16, 32, 64), dtype="float32") = lv + R.output(gv) + return gv + + # Run verification for each case + verify_model(IndexPut1D(), input_info_1d, {}, Expected1D) + verify_model(IndexPut2D(), input_info_2d, {}, Expected2D) + verify_model(IndexPut3D(), input_info_3d, {}, Expected3D) + verify_model(IndexPut4D(), input_info_4d, {}, Expected4D) + verify_model(IndexPut5D(), input_info_5d, {}, Expected5D) + + def test_flip(): class Flip0(Module): def forward(self, data):
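Editorial aside, not part of the patch: the tests above only compare the imported IR structurally. A minimal sketch, assuming an LLVM-enabled TVM build and a PyTorch 2.x release with `torch.export`, of how the same path (torch.export, `from_exported_program`, `LegalizeOps`, Relax VM) could also be checked numerically against PyTorch's own `index_put_`; the module, shapes, and tolerances are arbitrary examples.

```python
import numpy as np
import torch
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program


class PutModel(torch.nn.Module):
    def forward(self, data, idx0, idx1, values):
        return data.index_put_((idx0, idx1), values, accumulate=True)


data_np = np.ones((3, 3), dtype="float32")
idx0_np = np.array([0, 2], dtype="int64")
idx1_np = np.array([1, 1], dtype="int64")
vals_np = np.array([1.0, 2.0], dtype="float32")

# Import through torch.export and lower relax.index_put to the TOPI kernel.
args = tuple(torch.from_numpy(a.copy()) for a in (data_np, idx0_np, idx1_np, vals_np))
mod = from_exported_program(torch.export.export(PutModel(), args))
mod = relax.transform.LegalizeOps()(mod)

# Run on the Relax VM (CPU) and compare against PyTorch's reference result.
vm = relax.VirtualMachine(relax.build(mod, target="llvm"), tvm.cpu())
tvm_out = vm["main"](*[tvm.nd.array(a) for a in (data_np, idx0_np, idx1_np, vals_np)])[0]

ref = torch.from_numpy(data_np.copy())
ref.index_put_((torch.from_numpy(idx0_np), torch.from_numpy(idx1_np)),
               torch.from_numpy(vals_np), accumulate=True)
np.testing.assert_allclose(tvm_out.numpy(), ref.numpy(), rtol=1e-5)
```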