From c8545adb2cada41767396e264eae5c14572a8ad7 Mon Sep 17 00:00:00 2001
From: Pratheesh
Date: Mon, 21 Apr 2025 04:25:08 +0000
Subject: [PATCH 01/16] add support for index_put_ op

---
 include/tvm/relax/attrs/manipulate.h             |  14 ++
 .../torch/base_fx_graph_translator.py            |  17 +++
 .../torch/exported_program_translator.py         |   1 +
 .../tvm/relax/frontend/torch/fx_translator.py    |   1 +
 python/tvm/relax/op/__init__.py                  |   1 +
 python/tvm/relax/op/manipulate.py                |  53 +++++++-
 python/tvm/relax/op/op_attrs.py                  |   4 +
 .../transform/legalize_ops/manipulate.py         |  16 +++
 python/tvm/script/ir_builder/relax/ir.py         |   2 +
 python/tvm/topi/__init__.py                      |   1 +
 python/tvm/topi/index_put.py                     | 119 +++++++++++++++++
 src/relax/op/tensor/manipulate.cc                | 122 ++++++++++++++++++
 src/relax/op/tensor/manipulate.h                 |  13 ++
 .../test_frontend_from_exported_program.py       | 104 +++++++++++++++
 tests/python/relax/test_frontend_from_fx.py      |  83 ++++++++++++
 15 files changed, 550 insertions(+), 1 deletion(-)
 create mode 100644 python/tvm/topi/index_put.py

diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index 67f99d9b417e..8eb0b087bdd0 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -182,6 +182,20 @@ struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
   }
 };  // struct GatherNDAttrs
 
+/*! \brief Attributes used in index_put operator */
+struct IndexPutAttrs : public tvm::AttrsNode<IndexPutAttrs> {
+  bool accumulate;
+
+  TVM_DECLARE_ATTRS(IndexPutAttrs, "relax.attrs.IndexPutAttrs") {
+    TVM_ATTR_FIELD(accumulate)
+      .set_default(false)
+      .describe(
+          "Whether to accumulate (add) values rather than replace. "
+          "If true, performs tensor[indices] += values, "
+          "otherwise performs tensor[indices] = values.");
+  }
+};  // struct IndexPutAttrs
+
 /*! \brief Attributes used in scatter_elements operators */
 struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
   Integer axis;
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 33f6ffc3132e..8e6f36f509a3 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1117,6 +1117,23 @@ def _gather(self, node: fx.Node) -> relax.Var:
         dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
         index = self.env[node.args[2]]
         return self.block_builder.emit(relax.op.gather_elements(x, index, axis=dim))
+
+    def _index_put_(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        tensor = args[0]
+        indices = args[1] if len(args) > 1 else node.kwargs.get("indices", ())
+        values = args[2] if len(args) > 2 else node.kwargs.get("values")
+        accumulate = args[3] if len(args) > 3 else node.kwargs.get("accumulate", False)
+
+        # Ensure accumulate is a boolean
+        if isinstance(accumulate, str):
+            accumulate = accumulate.lower() == "true"
+        elif not isinstance(accumulate, bool):
+            accumulate = bool(accumulate)
+
+        if isinstance(indices, (list, tuple)):
+            indices = relax.Tuple(indices) if indices else relax.Tuple([])
+        return self.block_builder.emit(relax.op.index_put(tensor, indices, values, accumulate))
 
     def _index_tensor(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
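For reference, the eager-mode PyTorch behaviour this converter maps is `torch.Tensor.index_put_`; a minimal, runnable illustration (plain PyTorch, nothing patch-specific):

    import torch

    x = torch.zeros(3, 3)
    indices = (torch.tensor([0, 2]), torch.tensor([1, 1]))
    values = torch.tensor([1.0, 2.0])
    x.index_put_(indices, values, accumulate=False)  # x[0, 1] = 1.0, x[2, 1] = 2.0
    x.index_put_(indices, values, accumulate=True)   # adds on top: x[0, 1] == 2.0, x[2, 1] == 4.0
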
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 0434712050ed..962ade7455f2 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -433,6 +433,7 @@ def create_convert_map(
             "flip.default": self._flip,
             "gather.default": self._gather,
             "index.Tensor": self._index_tensor,
+            "index_put_.default": self._index_put_,
             "narrow.default": self._narrow,
             "permute.default": self._permute,
             "repeat.default": self._repeat,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 55abf20fcc03..dbe42c8f4f17 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -801,6 +801,7 @@ def create_convert_map(
             "flatten": self._flatten,
             "flip": self._flip,
             "gather": self._gather,
+            "index_put_": self._index_put_,
             "narrow": self._narrow,
             "numel": self._numel,
             "permute": self._permute,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 097313a33da2..7b8c34b6415f 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -95,6 +95,7 @@
     flip,
     gather_elements,
     gather_nd,
+    index_put,
     index_tensor,
     layout_transform,
     one_hot,
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index a693adf4325c..8a963013b299 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -562,7 +562,7 @@ def index_tensor(data: Expr, indices: Union[Expr, List[Expr]]) -> Expr:
         or a Python ``list`` / ``tuple`` that will be promoted to a tuple
         expression automatically. Each tensor must have an integer dtype.
-
+
     Returns
     -------
     result : relax.Expr
@@ -595,6 +595,57 @@ def index_tensor(data: Expr, indices: Union[Expr, List[Expr]]) -> Expr:
     return _ffi_api.index_tensor(data, indices)  # type: ignore
 
 
+def index_put(
+    data: Expr,
+    indices: Union[Expr, Tuple[Expr]],
+    values: Expr,
+    accumulate: bool = False,
+    reduction: str = "update"
+) -> Expr:
+    """This operation updates values in `data` at positions
+    specified by `indices` with corresponding values from `values`. The `indices` is a tuple
+    of tensors where each tensor corresponds to a dimension in `data`.
+    When `accumulate` is True, the operation performs accumulation (addition) rather than
+    replacement. The `reduction` parameter allows specifying different reduction operations.
+    Parameters
+    ----------
+    data : relax.Expr
+        The input tensor to be modified
+    indices : Union[Expr, Tuple[Expr]]
+        Tuple of index tensors (one for each dimension) specifying positions to update
+    values : relax.Expr
+        Values to place at the specified indices
+    accumulate : bool
+        Whether to accumulate (add) values rather than replace (default: False)
+    Returns
+    -------
+    result : relax.Expr
+        A new tensor with the same shape as data but with specified positions updated
+    Examples
+    --------
+    .. code-block:: python
+        # inputs
+        data = torch.zeros(3, 3)
+        indices = (torch.tensor([0, 2]), torch.tensor([1, 1]))
+        values = torch.tensor([1.0, 2.0])
+        # output
+        output = [
+            [0.0, 1.0, 0.0],
+            [0.0, 0.0, 0.0],
+            [0.0, 2.0, 0.0],
+        ]
+        # with accumulate=True (applied again on the output above)
+        output = [
+            [0.0, 2.0, 0.0],
+            [0.0, 0.0, 0.0],
+            [0.0, 4.0, 0.0],
+        ]
+    """
+    if isinstance(indices, (list, tuple)):
+        indices = RxTuple(indices) if indices else RxTuple([])
+    return _ffi_api.index_put(data, indices, values, accumulate)  # type: ignore
+
+
 def scatter_elements(
     data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str = "update"
 ):
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index fda4258a093b..0c955e9473e2 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -143,7 +143,11 @@ class SqueezeAttrs(Attrs):
 class StackAttrs(Attrs):
     """Attributes for concat operator"""
 
+@tvm._ffi.register_object("relax.attrs.IndexPutAttrs")
+class IndexPutAttrs(Attrs):
+    """Attributes for index_put operator"""
+
 
 @tvm._ffi.register_object("relax.attrs.LayoutTransformAttrs")
 class LayoutTransformAttrs(Attrs):
     """Attributes used in layout_transform operator"""
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 84baa887d9aa..3b9fff578ce3 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -193,6 +193,22 @@ def _index_tensor(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(topi.index_tensor, call.args[0], fields)
 
 
+@register_legalize("relax.index_put")
+def _index_put(bb: BlockBuilder, call: Call) -> Expr:
+    data = call.args[0]
+    indices = call.args[1]
+    values = call.args[2]
+    accumulate = call.attrs.accumulate
+
+    # If indices is a Tuple, unpack it into individual tensors
+    if isinstance(indices, relax.Tuple):
+        indices_list = [indices.fields[i] for i in range(len(indices.fields))]
+    else:
+        indices_list = [indices]
+
+    return bb.call_te(topi.index_put, data, indices_list, values, accumulate=accumulate,)
+
+
 @register_legalize("relax.scatter_elements")
 def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 22b00cd70416..d2952ed8e0d7 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -100,6 +100,7 @@
     greater,
     greater_equal,
     hint_on_device,
+    index_put,
     image,
     index_tensor,
     invoke_closure,
@@ -785,6 +786,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "greater_equal",
     "hexagon",
     "hint_on_device",
+    "index_put",
     "image",
     "index_tensor",
     "invoke_closure",
diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py
index 1de6941c9923..fa4e98a89a42 100644
--- a/python/tvm/topi/__init__.py
+++ b/python/tvm/topi/__init__.py
@@ -33,6 +33,7 @@
 from .math import *
 from .tensor import *
 from .generic_op_impl import *
+from .index_put import *
 from .reduction import *
 from .transform import *
 from .broadcast import *
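The Relax definition and the legalization above follow NumPy-style advanced-index assignment. A reference implementation in plain NumPy for checking expected outputs (illustrative only; `index_put_ref` is a name local to this note, not part of the patch):

    import numpy as np

    def index_put_ref(data, indices, values, accumulate=False):
        # indices: one 1-D integer array per dimension of `data`
        out = data.copy()
        if accumulate:
            # np.add.at applies += unbuffered, so repeated indices accumulate correctly
            np.add.at(out, tuple(indices), values)
        else:
            out[tuple(indices)] = values
        return out

    data = np.zeros((3, 3), dtype="float32")
    indices = (np.array([0, 2]), np.array([1, 1]))
    values = np.array([1.0, 2.0], dtype="float32")
    assert index_put_ref(data, indices, values)[0, 1] == 1.0
    assert index_put_ref(data, indices, values, accumulate=True)[2, 1] == 2.0
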
diff --git a/python/tvm/topi/index_put.py b/python/tvm/topi/index_put.py
new file mode 100644
index 000000000000..8b486d41a0e7
--- /dev/null
+++ b/python/tvm/topi/index_put.py
@@ -0,0 +1,119 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""IndexPut operator"""
+from tvm import te
+from tvm import tir
+from . import utils
+from .math import cast
+
+
+def index_put(data, indices, values, accumulate=False):
+    """Put values into an array according to indices.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The source array to be modified.
+
+    indices : Tuple[tvm.te.Tensor]
+        Tuple of 1D index tensors (one for each dimension) specifying positions.
+
+    values : tvm.te.Tensor
+        The values to place at the specified indices.
+
+    accumulate : bool, optional
+        Whether to accumulate (add) values rather than replace.
+        If True, performs tensor[indices] += values
+        If False, performs tensor[indices] = values
+        Default is False.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+    """
+    if not isinstance(indices, (list, tuple)):
+        indices = [indices]
+
+    # Check indices match data dimensions
+    if len(indices) != len(data.shape):
+        raise ValueError(
+            f"Number of index tensors ({len(indices)}) must match "
+            f"data dimensions ({len(data.shape)})"
+        )
+
+    # Prepare ranges and strides
+    shape = data.shape
+    full_range = 1
+    for dim in shape:
+        full_range *= dim
+
+    # Check all indices have same length
+    index_len = indices[0].shape[0]
+    for idx in indices[1:]:
+        if not utils.equal_const_int(idx.shape[0], index_len):
+            raise ValueError("All index tensors must have same length")
+
+    def gen_ir(data_ptr, index_ptrs, values_ptr, out_ptr, reduce_func):
+        ib = tir.ir_builder.create()
+
+        data = ib.buffer_ptr(data_ptr)
+        indices = [ib.buffer_ptr(idx) for idx in index_ptrs]
+        values = ib.buffer_ptr(values_ptr)
+        out = ib.buffer_ptr(out_ptr)
+
+        # Copy initial input data to output
+        with ib.for_range(0, full_range, "i", kind="parallel") as i:
+            out[i] = data[i]
+
+        with ib.for_range(0, index_len, "k", kind="parallel") as k:
+            # Calculate multi-dimensional index
+            flat_index = 0
+            stride = 1
+            for dim in range(len(shape)-1, -1, -1):
+                # Get index and shift to positive if needed
+                idx_val = indices[dim][k]
+                shifted_idx = idx_val + (idx_val < 0) * shape[dim]
+                flat_index += shifted_idx * stride
+                stride *= shape[dim]
+
+            reduce_func(out, flat_index, values[k])
+
+        return ib.get()
+
+    def update_func(dst_ptr, dst_index, update):
+        dst_ptr[dst_index] = update
+
+    def add_func(dst_ptr, dst_index, update):
+        dst_ptr[dst_index] += update
+
+    reduce_func = add_func if accumulate else update_func
+
+    # Prepare input buffers
+    in_buffers = [data]
+    in_buffers.extend(indices)
+    in_buffers.append(values)
+
+    out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf")
+    return te.extern(
+        [data.shape],
+        in_buffers,
+        lambda ins, outs: gen_ir(ins[0], ins[1:-1], ins[-1], outs[0], reduce_func),
+        dtype=data.dtype,
+        out_buffers=[out_buf],
+        name="index_put.generic",
+        tag="index_put.generic",
+    )
\ No newline at end of file
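Assuming the patch is applied, the TE-level kernel above can be instantiated directly; a small sketch (shapes and placeholder names chosen here for illustration):

    import tvm
    from tvm import te, topi

    data = te.placeholder((4, 4), name="data", dtype="float32")
    i0 = te.placeholder((2,), name="i0", dtype="int64")
    i1 = te.placeholder((2,), name="i1", dtype="int64")
    vals = te.placeholder((2,), name="vals", dtype="float32")
    # one 1-D index tensor per data dimension, as the checks above require
    out = topi.index_put(data, [i0, i1], vals, accumulate=True)

In practice the kernel is reached through the `relax.index_put` legalization via `bb.call_te` rather than called directly.
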
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index f56135a35bc3..09c67b2fc833 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -1972,6 +1972,128 @@ TVM_REGISTER_OP("relax.gather_nd")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGatherND)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.index_put */
+TVM_REGISTER_NODE_TYPE(IndexPutAttrs);
+
+Expr index_put(Expr data, Expr indices, Expr values, bool accumulate) {
+  auto attrs = make_object<IndexPutAttrs>();
+  attrs->accumulate = std::move(accumulate);
+  static const Op& op = Op::Get("relax.index_put");
+  return Call(op, {data, indices, values}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.index_put").set_body_typed(index_put);
+
+StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) {
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* values_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]);
+
+  auto diag_def = [&](const TensorStructInfoNode* sinfo, String name, String type_key) {
+    if (sinfo == nullptr) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "IndexPut requires the input " << name
+                       << " to be a Tensor. However, the given one is " << type_key);
+    }
+  };
+
+  diag_def(data_sinfo, "data", call->args[0]->struct_info_->GetTypeKey());
+  diag_def(values_sinfo, "values", call->args[2]->struct_info_->GetTypeKey());
+
+  // Handle indices: either a single tensor or a tuple of tensors
+  Array<TensorStructInfo> indices_tensors;
+
+  if (const auto* tuple_sinfo = GetStructInfoAs<TupleStructInfoNode>(call->args[1])) {
+    // Indices is a tuple of tensors
+    for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) {
+      const auto* tensor_sinfo = tuple_sinfo->fields[i].as<TensorStructInfoNode>();
+      if (tensor_sinfo == nullptr) {
+        ctx->ReportFatal(Diagnostic::Error(call)
+                         << "IndexPut requires each index in the indices tuple to be a Tensor. "
+                         << "However, element " << i << " is "
+                         << tuple_sinfo->fields[i]->GetTypeKey());
+      }
+      indices_tensors.push_back(GetRef<TensorStructInfo>(tensor_sinfo));
+    }
+  } else if (const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1])) {
+    // Indices is a single tensor
+    indices_tensors.push_back(GetRef<TensorStructInfo>(tensor_sinfo));
+  } else {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "IndexPut requires indices to be a Tensor or a tuple of Tensors. "
+                     << "However, the given one is " << call->args[1]->struct_info_->GetTypeKey());
+  }
+
+  if (data_sinfo->IsUnknownNdim()) {
+    return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice);
+  }
+
+  // Validate each index tensor
+  for (size_t i = 0; i < indices_tensors.size(); ++i) {
+    const auto& tensor_sinfo = indices_tensors[i];
+    if (!tensor_sinfo->IsUnknownNdim() && tensor_sinfo->ndim != 1) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "IndexPut requires each index tensor to be 1D. "
+                       << "However, index tensor " << i << " has ndim=" << tensor_sinfo->ndim);
+    }
+    if (tensor_sinfo->IsUnknownDtype()) {
+      LOG(WARNING) << "Data type of index tensor " << i
+                   << " has not been specified. Assume it has an integer type.";
+    } else if (!(tensor_sinfo->dtype.is_int() || tensor_sinfo->dtype.is_uint())) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "IndexPut requires each index tensor to have integer dtype. "
+                       << "However, index tensor " << i << " has dtype=" << tensor_sinfo->dtype);
+    }
+  }
+
+  // Check that the number of index tensors matches data dimensions
+  if (!data_sinfo->IsUnknownNdim() && indices_tensors.size() != static_cast<size_t>(data_sinfo->ndim)) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "IndexPut requires the number of index tensors (" << indices_tensors.size()
+                     << ") to match the data tensor dimensions (" << data_sinfo->ndim << ")");
+  }
+
+  // Check data and values dtype compatibility
+  if (data_sinfo->IsUnknownDtype() || values_sinfo->IsUnknownDtype()) {
+    auto diag_dtype = [&](const TensorStructInfoNode* sinfo, String name) {
+      if (sinfo->IsUnknownDtype()) {
+        LOG(WARNING) << "Data type of " << name
+                     << " has not been specified. Assume it matches the other operand.";
+      }
+    };
+    diag_dtype(data_sinfo, "data");
+    diag_dtype(values_sinfo, "values");
+  } else if (data_sinfo->dtype != values_sinfo->dtype) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "IndexPut requires the input data to have the same type as values. "
+                     << "However, the given types are data: " << data_sinfo->dtype
+                     << ", values: " << values_sinfo->dtype);
+  }
+
+  // Check values shape compatibility
+  const auto* values_shape = values_sinfo->shape.as<ShapeExprNode>();
+  if (values_shape) {
+    if (values_sinfo->ndim != 1) {
+      LOG(WARNING) << "IndexPut typically expects values to be 1D, but got ndim="
+                   << values_sinfo->ndim;
+    }
+  }
+
+  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+  if (data_shape) {
+    return TensorStructInfo(ShapeExpr(data_shape->values), data_sinfo->dtype, data_sinfo->vdevice);
+  }
+  return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice);
+}
+
+TVM_REGISTER_OP("relax.index_put")
+    .set_attrs_type<IndexPutAttrs>()
+    .set_num_inputs(3)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("indices", "Tensor", "The indices tensor(s).")
+    .add_argument("values", "Tensor", "The values to put.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoIndexPut)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 /* relax.scatter_elements */
 TVM_REGISTER_NODE_TYPE(ScatterElementsAttrs);
 
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 4580f9191bb8..2e4c92c15052 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -218,6 +218,19 @@ Expr gather_nd(Expr data, Expr indices, int batch_dims = 0);
  */
 Expr index_tensor(Expr data, Expr indices);
 
+/*!
+ * \brief Put values into an array according to indices.
+ * \param data The input tensor to be modified.
+ * \param indices The index positions where values should be placed.
+ *        This should be a tuple of 1D tensors (one for each dimension).
+ * \param values The values to place at the specified indices.
+ * \param accumulate Whether to accumulate (add) values rather than replace.
+ *        If true, equivalent to tensor[indices] += values.
+ *        If false, equivalent to tensor[indices] = values.
+ * \return The computed result with values placed at specified indices.
+ */
+Expr index_put(Expr data, Expr indices, Expr values, bool accumulate = false);
+
 /*!
  * \brief Scatter updates into an array according to indices.
  * \param data The input tensor.
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 108617991b1f..8beb64b9d126 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -4169,6 +4169,110 @@ def main(
     verify_model(Gather3(), example_args, {}, Expected3)
 
 
+def test_index_put():
+    # Test case 1: 2D input
+    class IndexPut2D(Module):
+        def forward(self, data, indices_0, indices_1, values):
+            indices_tuple = (indices_0, indices_1)
+            return data.index_put_(indices_tuple, values, accumulate=False)
+
+    # Test case 2: 3D input
+    class IndexPut3D(Module):
+        def forward(self, data, indices_0, indices_1, indices_2, values):
+            indices_tuple = (indices_0, indices_1, indices_2)
+            return data.index_put_(indices_tuple, values, accumulate=False)
+
+    # Test case 3: 4D input
+    class IndexPut4D(Module):
+        def forward(self, data, indices_0, indices_1, indices_2, indices_3, values):
+            indices_tuple = (indices_0, indices_1, indices_2, indices_3)
+            return data.index_put_(indices_tuple, values, accumulate=False)
+
+    @I.ir_module
+    class Expected2D:
+        @R.function
+        def main(
+            data: R.Tensor((4, 2), dtype="float32"),
+            indices_0: R.Tensor((2,), dtype="int64"),
+            indices_1: R.Tensor((2,), dtype="int64"),
+            values: R.Tensor((2,), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((4, 2), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((4, 2), dtype="float32") = R.index_put(
+                    data, R.tuple(indices_0, indices_1), values, accumulate=False
+                )
+                gv: R.Tuple(R.Tensor((4, 2), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected3D:
+        @R.function
+        def main(
+            data: R.Tensor((4, 2, 3), dtype="float32"),
+            indices_0: R.Tensor((2,), dtype="int64"),
+            indices_1: R.Tensor((2,), dtype="int64"),
+            indices_2: R.Tensor((2,), dtype="int64"),
+            values: R.Tensor((2,), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((4, 2, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((4, 2, 3), dtype="float32") = R.index_put(
+                    data, R.tuple(indices_0, indices_1, indices_2), values, accumulate=False
+                )
+                gv: R.Tuple(R.Tensor((4, 2, 3), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected4D:
+        @R.function
+        def main(
+            data: R.Tensor((4, 2, 3, 2), dtype="float32"),
+            indices_0: R.Tensor((2,), dtype="int64"),
+            indices_1: R.Tensor((2,), dtype="int64"),
+            indices_2: R.Tensor((2,), dtype="int64"),
+            indices_3: R.Tensor((2,), dtype="int64"),
+            values: R.Tensor((2,), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((4, 2, 3, 2), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((4, 2, 3, 2), dtype="float32") = R.index_put(
+                    data, R.tuple(indices_0, indices_1, indices_2, indices_3), values, accumulate=False
+                )
+                gv: R.Tuple(R.Tensor((4, 2, 3, 2), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    # Example arguments for tracing
+    example_args_2d = (
+        torch.randn(4, 2, dtype=torch.float32),
+        torch.randint(0, 4, (2,), dtype=torch.int64),
+        torch.randint(0, 2, (2,), dtype=torch.int64),
+        torch.randn(2, dtype=torch.float32),
+    )
+
+    example_args_3d = (
+        torch.randn(4, 2, 3, dtype=torch.float32),
+        torch.randint(0, 4, (2,), dtype=torch.int64),
+        torch.randint(0, 2, (2,), dtype=torch.int64),
+        torch.randint(0, 3, (2,), dtype=torch.int64),
+        torch.randn(2, dtype=torch.float32),
+    )
+
+    example_args_4d = (
+        torch.randn(4, 2, 3, 2, dtype=torch.float32),
+        torch.randint(0, 4, (2,), dtype=torch.int64),
+        torch.randint(0, 2, (2,), dtype=torch.int64),
+        torch.randint(0, 3, (2,), dtype=torch.int64),
+        torch.randint(0, 2, (2,), dtype=torch.int64),
+        torch.randn(2, dtype=torch.float32),
+    )
+
+    # Run verification for each case
+    verify_model(IndexPut2D(), example_args_2d, {}, Expected2D)
+    verify_model(IndexPut3D(), example_args_3d, {}, Expected3D)
+    verify_model(IndexPut4D(), example_args_4d, {}, Expected4D)
+
+
 def test_flip():
     class Flip0(Module):
         def forward(self, data):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index cb69398e0a00..c5870d455e2f 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4404,6 +4404,89 @@ def main(
     verify_model(Gather3(), [([2, 3], "float32"), ([2, 3], "int32")], {}, Expected3)
 
 
+def test_index_put():
+    # Test case 1: 2D input
+    class IndexPut2D(Module):
+        def forward(self, data, indices_0, indices_1, values):
+            indices_tuple = (indices_0, indices_1)
+            return data.index_put_(indices_tuple, values, accumulate=False)
+
+    # Test case 2: 3D input
+    class IndexPut3D(Module):
+        def forward(self, data, indices_0, indices_1, indices_2, values):
+            indices_tuple = (indices_0, indices_1, indices_2)
+            return data.index_put_(indices_tuple, values, accumulate=False)
+
+    # Test case 3: 4D input
+    class IndexPut4D(Module):
+        def forward(self, data, indices_0, indices_1, indices_2, indices_3, values):
+            indices_tuple = (indices_0, indices_1, indices_2, indices_3)
+            return data.index_put_(indices_tuple, values, accumulate=False)
+
+    @I.ir_module
+    class Expected2D:
+        @R.function
+        def main(
+            data: R.Tensor((4, 2), dtype="float32"),
+            indices_0: R.Tensor((2,), dtype="int64"),
+            indices_1: R.Tensor((2,), dtype="int64"),
+            values: R.Tensor((2,), "float32"),
+        ) -> R.Tensor((4, 2), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4, 2), dtype="float32") = R.index_put(
+                    data, R.tuple(indices_0, indices_1), values, accumulate=False
+                )
+                gv: R.Tensor((4, 2), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected3D:
+        @R.function
+        def main(
+            data: R.Tensor((4, 2, 3), dtype="float32"),
+            indices_0: R.Tensor((2,), dtype="int64"),
+            indices_1: R.Tensor((2,), dtype="int64"),
+            indices_2: R.Tensor((2,), dtype="int64"),
+            values: R.Tensor((2,), "float32"),
+        ) -> R.Tensor((4, 2, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4, 2, 3), dtype="float32") = R.index_put(
+                    data, R.tuple(indices_0, indices_1, indices_2), values, accumulate=False
+                )
+                gv: R.Tensor((4, 2, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected4D:
+        @R.function
+        def main(
+            data: R.Tensor((4, 2, 3, 2), dtype="float32"),
+            indices_0: R.Tensor((2,), dtype="int64"),
+            indices_1: R.Tensor((2,), dtype="int64"),
+            indices_2: R.Tensor((2,), dtype="int64"),
+            indices_3: R.Tensor((2,), dtype="int64"),
+            values: R.Tensor((2,), "float32"),
+        ) -> R.Tensor((4, 2, 3, 2), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4, 2, 3, 2), dtype="float32") = R.index_put(
+                    data, R.tuple(indices_0, indices_1, indices_2, indices_3), values, accumulate=False
+                )
+                gv: R.Tensor((4, 2, 3, 2), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    input_info_2d = [((4, 2), "float32"), ((2,), "int64"), ((2,), "int64"), ((2,), "float32")]
+    input_info_3d = [((4, 2, 3), "float32"), ((2,), "int64"), ((2,), "int64"), ((2,), "int64"), ((2,), "float32")]
+    input_info_4d = [((4, 2, 3, 2), "float32"), ((2,), "int64"), ((2,), "int64"), ((2,), "int64"), ((2,), "int64"), ((2,), "float32")]
+
+    # Run verification for each case
+    verify_model(IndexPut2D(), input_info_2d, {}, Expected2D)
+    verify_model(IndexPut3D(), input_info_3d, {}, Expected3D)
+    verify_model(IndexPut4D(), input_info_4d, {}, Expected4D)
+
+
 def test_flip():
     class Flip0(Module):
         def forward(self, data):

From dc7bf63c74f0487561c187dd6de8a7204da781f8 Mon Sep 17 00:00:00 2001
From: Pratheesh
Date: Mon, 21 Apr 2025 05:39:18 +0000
Subject: [PATCH 02/16] add test cases from 1D to 5D in both exported program
 and fx graph

---
 .../test_frontend_from_exported_program.py | 181 ++++++++++++------
 tests/python/relax/test_frontend_from_fx.py | 136 +++++++++----
 2 files changed, 219 insertions(+), 98 deletions(-)

diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 8beb64b9d126..7eb522bb6e62 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -4170,109 +4170,174 @@ def main(
 
 def test_index_put():
-    # Test case 1: 2D input
+    # Test case 1: 1D input
+    class IndexPut1D(Module):
+        def forward(self, data, indices_0, values):
+            indices_tuple = (indices_0,)
+            return data.index_put_(indices_tuple, values, accumulate=False)
+
+    example_args_1d = (
+        torch.randn(64, dtype=torch.float32),
+        torch.randint(0, 64, (128,), dtype=torch.int64),
+        torch.randn(128, dtype=torch.float32),
+    )
+
+    @I.ir_module
+    class Expected1D:
+        @R.function
+        def main(
+            data: R.Tensor((64,), dtype="float32"),
+            indices_0: R.Tensor((128,), dtype="int64"),
+            values: R.Tensor((128,), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((64,), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((64,), dtype="float32") = R.index_put(
+                    data, R.tuple(indices_0), values, accumulate=False
+                )
+                gv: R.Tuple(R.Tensor((64,), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    # Test case 2: 2D input
     class IndexPut2D(Module):
         def forward(self, data, indices_0, indices_1, values):
             indices_tuple = (indices_0, indices_1)
             return data.index_put_(indices_tuple, values, accumulate=False)
 
-    # Test case 2: 3D input
-    class IndexPut3D(Module):
-        def forward(self, data, indices_0, indices_1, indices_2, values):
-            indices_tuple = (indices_0, indices_1, indices_2)
-            return data.index_put_(indices_tuple, values, accumulate=False)
-
-    # Test case 3: 4D input
-    class IndexPut4D(Module):
-        def forward(self, data, indices_0, indices_1, indices_2, indices_3, values):
-            indices_tuple = (indices_0, indices_1, indices_2, indices_3)
-            return data.index_put_(indices_tuple, values, accumulate=False)
+    example_args_2d = (
+        torch.randn(32, 64, dtype=torch.float32),
+        torch.randint(0, 32, (128,), dtype=torch.int64),
+        torch.randint(0, 64, (128,), dtype=torch.int64),
+        torch.randn(128, dtype=torch.float32),
+    )
 
     @I.ir_module
     class Expected2D:
         @R.function
         def main(
-            data: R.Tensor((4, 2), dtype="float32"),
-            indices_0: R.Tensor((2,), dtype="int64"),
-            indices_1: R.Tensor((2,), dtype="int64"),
-            values: R.Tensor((2,), dtype="float32"),
-        ) -> R.Tuple(R.Tensor((4, 2), dtype="float32")):
+            data: R.Tensor((32, 64), dtype="float32"),
+            indices_0: R.Tensor((128,), dtype="int64"),
+            indices_1: R.Tensor((128,), dtype="int64"),
+            values: R.Tensor((128,), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((32, 64), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((4, 2), dtype="float32") = R.index_put(
+                lv: R.Tensor((32, 64), dtype="float32") = R.index_put(
                     data, R.tuple(indices_0, indices_1), values, accumulate=False
                 )
dtype="float32")) = (lv,) + gv: R.Tuple(R.Tensor((32, 64), dtype="float32")) = (lv,) R.output(gv) return gv + # Test case 3: 3D input + class IndexPut3D(Module): + def forward(self, data, indices_0, indices_1, indices_2, values): + indices_tuple = (indices_0, indices_1, indices_2) + return data.index_put_(indices_tuple, values, accumulate=False) + + example_args_3d = ( + torch.randn(16, 32, 64, dtype=torch.float32), + torch.randint(0, 16, (128,), dtype=torch.int64), + torch.randint(0, 32, (128,), dtype=torch.int64), + torch.randint(0, 64, (128,), dtype=torch.int64), + torch.randn(128, dtype=torch.float32), + ) + @I.ir_module class Expected3D: @R.function def main( - data: R.Tensor((4, 2, 3), dtype="float32"), - indices_0: R.Tensor((2,), dtype="int64"), - indices_1: R.Tensor((2,), dtype="int64"), - indices_2: R.Tensor((2,), dtype="int64"), - values: R.Tensor((2,), dtype="float32"), - ) -> R.Tuple(R.Tensor((4, 2, 3), dtype="float32")): + data: R.Tensor((16, 32, 64), dtype="float32"), + indices_0: R.Tensor((128,), dtype="int64"), + indices_1: R.Tensor((128,), dtype="int64"), + indices_2: R.Tensor((128,), dtype="int64"), + values: R.Tensor((128,), dtype="float32"), + ) -> R.Tuple(R.Tensor((16, 32, 64), dtype="float32")): with R.dataflow(): - lv: R.Tensor((4, 2, 3), dtype="float32") = R.index_put( + lv: R.Tensor((16, 32, 64), dtype="float32") = R.index_put( data, R.tuple(indices_0, indices_1, indices_2), values, accumulate=False ) - gv: R.Tuple(R.Tensor((4, 2, 3), dtype="float32")) = (lv,) + gv: R.Tuple(R.Tensor((16, 32, 64), dtype="float32")) = (lv,) R.output(gv) return gv + # Test case 4: 4D input + class IndexPut4D(Module): + def forward(self, data, indices_0, indices_1, indices_2, indices_3, values): + indices_tuple = (indices_0, indices_1, indices_2, indices_3) + return data.index_put_(indices_tuple, values, accumulate=False) + + example_args_4d = ( + torch.randn(8, 16, 32, 64, dtype=torch.float32), + torch.randint(0, 8, (128,), dtype=torch.int64), + torch.randint(0, 16, (128,), dtype=torch.int64), + torch.randint(0, 32, (128,), dtype=torch.int64), + torch.randint(0, 64, (128,), dtype=torch.int64), + torch.randn(128, dtype=torch.float32), + ) + @I.ir_module class Expected4D: @R.function def main( - data: R.Tensor((4, 2, 3, 2), dtype="float32"), - indices_0: R.Tensor((2,), dtype="int64"), - indices_1: R.Tensor((2,), dtype="int64"), - indices_2: R.Tensor((2,), dtype="int64"), - indices_3: R.Tensor((2,), dtype="int64"), - values: R.Tensor((2,), dtype="float32"), - ) -> R.Tuple(R.Tensor((4, 2, 3, 2), dtype="float32")): + data: R.Tensor((8, 16, 32, 64), dtype="float32"), + indices_0: R.Tensor((128,), dtype="int64"), + indices_1: R.Tensor((128,), dtype="int64"), + indices_2: R.Tensor((128,), dtype="int64"), + indices_3: R.Tensor((128,), dtype="int64"), + values: R.Tensor((128,), dtype="float32"), + ) -> R.Tuple(R.Tensor((8, 16, 32, 64), dtype="float32")): with R.dataflow(): - lv: R.Tensor((4, 2, 3, 2), dtype="float32") = R.index_put( + lv: R.Tensor((8, 16, 32, 64), dtype="float32") = R.index_put( data, R.tuple(indices_0, indices_1, indices_2, indices_3), values, accumulate=False ) - gv: R.Tuple(R.Tensor((4, 2, 3, 2), dtype="float32")) = (lv,) + gv: R.Tuple(R.Tensor((8, 16, 32, 64), dtype="float32")) = (lv,) R.output(gv) return gv - # Example arguments for tracing - example_args_2d = ( - torch.randn(4, 2, dtype=torch.float32), - torch.randint(0, 4, (2,), dtype=torch.int64), - torch.randint(0, 2, (2,), dtype=torch.int64), - torch.randn(2, dtype=torch.float32), - ) + # Test case 5: 5D input + class 
IndexPut5D(Module): + def forward(self, data, indices_0, indices_1, indices_2, indices_3, indices_4, values): + indices_tuple = (indices_0, indices_1, indices_2, indices_3, indices_4) + return data.index_put_(indices_tuple, values, accumulate=False) - example_args_3d = ( - torch.randn(4, 2, 3, dtype=torch.float32), - torch.randint(0, 4, (2,), dtype=torch.int64), - torch.randint(0, 2, (2,), dtype=torch.int64), - torch.randint(0, 3, (2,), dtype=torch.int64), - torch.randn(2, dtype=torch.float32), + example_args_5d = ( + torch.randn(4, 8, 16, 32, 64, dtype=torch.float32), + torch.randint(0, 4, (128,), dtype=torch.int64), + torch.randint(0, 8, (128,), dtype=torch.int64), + torch.randint(0, 16, (128,), dtype=torch.int64), + torch.randint(0, 32, (128,), dtype=torch.int64), + torch.randint(0, 64, (128,), dtype=torch.int64), + torch.randn(128, dtype=torch.float32), ) - example_args_4d = ( - torch.randn(4, 2, 3, 2, dtype=torch.float32), - torch.randint(0, 4, (2,), dtype=torch.int64), - torch.randint(0, 2, (2,), dtype=torch.int64), - torch.randint(0, 3, (2,), dtype=torch.int64), - torch.randint(0, 2, (2,), dtype=torch.int64), - torch.randn(2, dtype=torch.float32), - ) + @I.ir_module + class Expected5D: + @R.function + def main( + data: R.Tensor((4, 8, 16, 32, 64), dtype="float32"), + indices_0: R.Tensor((128,), dtype="int64"), + indices_1: R.Tensor((128,), dtype="int64"), + indices_2: R.Tensor((128,), dtype="int64"), + indices_3: R.Tensor((128,), dtype="int64"), + indices_4: R.Tensor((128,), dtype="int64"), + values: R.Tensor((128,), dtype="float32"), + ) -> R.Tuple(R.Tensor((4, 8, 16, 32, 64), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((4, 8, 16, 32, 64), dtype="float32") = R.index_put( + data, R.tuple(indices_0, indices_1, indices_2, indices_3, indices_4), values, accumulate=False + ) + gv: R.Tuple(R.Tensor((4, 8, 16, 32, 64), dtype="float32")) = (lv,) + R.output(gv) + return gv # Run verification for each case + verify_model(IndexPut1D(), example_args_1d, {}, Expected1D) verify_model(IndexPut2D(), example_args_2d, {}, Expected2D) verify_model(IndexPut3D(), example_args_3d, {}, Expected3D) verify_model(IndexPut4D(), example_args_4d, {}, Expected4D) + verify_model(IndexPut5D(), example_args_5d, {}, Expected5D) + - def test_flip(): class Flip0(Module): def forward(self, data): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index c5870d455e2f..15b4160e4d18 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -4405,86 +4405,142 @@ def main( def test_index_put(): - # Test case 1: 2D input + # Test case 1: 1D input + class IndexPut1D(Module): + def forward(self, data, indices_0, values): + indices_tuple = (indices_0,) + return data.index_put_(indices_tuple, values, accumulate=False) + + input_info_1d = [((64,), "float32"), ((128,), "int64"), ((128,), "float32")] + + @I.ir_module + class Expected1D: + @R.function + def main( + data: R.Tensor((64,), dtype="float32"), + indices_0: R.Tensor((128,), dtype="int64"), + values: R.Tensor((128,), dtype="float32"), + ) -> R.Tensor((64,), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((64,), dtype="float32") = R.index_put( + data, R.tuple(indices_0), values, accumulate=False + ) + gv: R.Tensor((64,), dtype="float32") = lv + R.output(gv) + return gv + + # Test case 2: 2D input class IndexPut2D(Module): def forward(self, data, indices_0, indices_1, values): indices_tuple = (indices_0, indices_1) return data.index_put_(indices_tuple, 
values, accumulate=False) - # Test case 2: 3D input - class IndexPut3D(Module): - def forward(self, data, indices_0, indices_1, indices_2, values): - indices_tuple = (indices_0, indices_1, indices_2) - return data.index_put_(indices_tuple, values, accumulate=False) + input_info_2d = [((32, 64), "float32"), ((128,), "int64"), ((128,), "int64"), ((128,), "float32")] - # Test case 3: 4D input - class IndexPut4D(Module): - def forward(self, data, indices_0, indices_1, indices_2, indices_3, values): - indices_tuple = (indices_0, indices_1, indices_2, indices_3) - return data.index_put_(indices_tuple, values, accumulate=False) - @I.ir_module class Expected2D: @R.function def main( - data: R.Tensor((4, 2), dtype="float32"), - indices_0: R.Tensor((2,), dtype="int64"), - indices_1: R.Tensor((2,), dtype="int64"), - values: R.Tensor((2,), "float32"), - ) -> R.Tensor((4, 2), dtype="float32"): + data: R.Tensor((32, 64), dtype="float32"), + indices_0: R.Tensor((128,), dtype="int64"), + indices_1: R.Tensor((128,), dtype="int64"), + values: R.Tensor((128,), dtype="float32"), + ) -> R.Tensor((32, 64), dtype="float32"): with R.dataflow(): - lv: R.Tensor((4, 2), dtype="float32") = R.index_put( + lv: R.Tensor((32, 64), dtype="float32") = R.index_put( data, R.tuple(indices_0, indices_1), values, accumulate=False ) - gv: R.Tensor((4, 2), dtype="float32") = lv + gv: R.Tensor((32, 64), dtype="float32") = lv R.output(gv) return gv - + + # Test case 3: 3D input + class IndexPut3D(Module): + def forward(self, data, indices_0, indices_1, indices_2, values): + indices_tuple = (indices_0, indices_1, indices_2) + return data.index_put_(indices_tuple, values, accumulate=False) + + input_info_3d = [((16, 32, 64), "float32"), ((128,), "int64"), ((128,), "int64"), ((128,), "int64"), ((128,), "float32")] + @I.ir_module class Expected3D: @R.function def main( - data: R.Tensor((4, 2, 3), dtype="float32"), - indices_0: R.Tensor((2,), dtype="int64"), - indices_1: R.Tensor((2,), dtype="int64"), - indices_2: R.Tensor((2,), dtype="int64"), - values: R.Tensor((2,), "float32"), - ) -> R.Tensor((4, 2, 3), dtype="float32"): + data: R.Tensor((16, 32, 64), dtype="float32"), + indices_0: R.Tensor((128,), dtype="int64"), + indices_1: R.Tensor((128,), dtype="int64"), + indices_2: R.Tensor((128,), dtype="int64"), + values: R.Tensor((128,), dtype="float32"), + ) -> R.Tensor((16, 32, 64), dtype="float32"): with R.dataflow(): - lv: R.Tensor((4, 2, 3), dtype="float32") = R.index_put( + lv: R.Tensor((16, 32, 64), dtype="float32") = R.index_put( data, R.tuple(indices_0, indices_1, indices_2), values, accumulate=False ) - gv: R.Tensor((4, 2, 3), dtype="float32") = lv + gv: R.Tensor((16, 32, 64), dtype="float32") = lv R.output(gv) return gv + # Test case 4: 4D input + class IndexPut4D(Module): + def forward(self, data, indices_0, indices_1, indices_2, indices_3, values): + indices_tuple = (indices_0, indices_1, indices_2, indices_3) + return data.index_put_(indices_tuple, values, accumulate=False) + + input_info_4d = [((8, 16, 32, 64), "float32"), ((128,), "int64"), ((128,), "int64"), ((128,), "int64"), ((128,), "int64"), ((128,), "float32")] + @I.ir_module class Expected4D: @R.function def main( - data: R.Tensor((4, 2, 3, 2), dtype="float32"), - indices_0: R.Tensor((2,), dtype="int64"), - indices_1: R.Tensor((2,), dtype="int64"), - indices_2: R.Tensor((2,), dtype="int64"), - indices_3: R.Tensor((2,), dtype="int64"), - values: R.Tensor((2,), "float32"), - ) -> R.Tensor((4, 2, 3, 2), dtype="float32"): + data: R.Tensor((8, 16, 32, 64), 
dtype="float32"), + indices_0: R.Tensor((128,), dtype="int64"), + indices_1: R.Tensor((128,), dtype="int64"), + indices_2: R.Tensor((128,), dtype="int64"), + indices_3: R.Tensor((128,), dtype="int64"), + values: R.Tensor((128,), dtype="float32"), + ) -> R.Tensor((8, 16, 32, 64), dtype="float32"): with R.dataflow(): - lv: R.Tensor((4, 2, 3, 2), dtype="float32") = R.index_put( + lv: R.Tensor((8, 16, 32, 64), dtype="float32") = R.index_put( data, R.tuple(indices_0, indices_1, indices_2, indices_3), values, accumulate=False ) - gv: R.Tensor((4, 2, 3, 2), dtype="float32") = lv + gv: R.Tensor((8, 16, 32, 64), dtype="float32") = lv R.output(gv) return gv - input_info_2d = [((4, 2), "float32"), ((2,), "int64"), ((2,), "int64"), ((2,), "float32")] - input_info_3d = [((4, 2, 3), "float32"), ((2,), "int64"), ((2,), "int64"), ((2,), "int64"), ((2,), "float32")] - input_info_4d = [((4, 2, 3, 2), "float32"), ((2,), "int64"), ((2,), "int64"), ((2,), "int64"), ((2,), "int64"), ((2,), "float32")] + # Test case 5: 5D input + class IndexPut5D(Module): + def forward(self, data, indices_0, indices_1, indices_2, indices_3, indices_4, values): + indices_tuple = (indices_0, indices_1, indices_2, indices_3, indices_4) + return data.index_put_(indices_tuple, values, accumulate=False) + + input_info_5d = [((4, 8, 16, 32, 64), "float32"), ((128,), "int64"), ((128,), "int64"), ((128,), "int64"), ((128,), "int64"), ((128,), "int64"), ((128,), "float32")] + + @I.ir_module + class Expected5D: + @R.function + def main( + data: R.Tensor((4, 8, 16, 32, 64), dtype="float32"), + indices_0: R.Tensor((128,), dtype="int64"), + indices_1: R.Tensor((128,), dtype="int64"), + indices_2: R.Tensor((128,), dtype="int64"), + indices_3: R.Tensor((128,), dtype="int64"), + indices_4: R.Tensor((128,), dtype="int64"), + values: R.Tensor((128,), dtype="float32"), + ) -> R.Tensor((4, 8, 16, 32, 64), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((4, 8, 16, 32, 64), dtype="float32") = R.index_put( + data, R.tuple(indices_0, indices_1, indices_2, indices_3, indices_4), values, accumulate=False + ) + gv: R.Tensor((4, 8, 16, 32, 64), dtype="float32") = lv + R.output(gv) + return gv # Run verification for each case + verify_model(IndexPut1D(), input_info_1d, {}, Expected1D) verify_model(IndexPut2D(), input_info_2d, {}, Expected2D) verify_model(IndexPut3D(), input_info_3d, {}, Expected3D) verify_model(IndexPut4D(), input_info_4d, {}, Expected4D) + verify_model(IndexPut5D(), input_info_5d, {}, Expected5D) def test_flip(): From ba84f0f64438201cbd08f8446b0b447837715965 Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Mon, 21 Apr 2025 08:25:42 +0000 Subject: [PATCH 03/16] lint issues --- .../torch/base_fx_graph_translator.py | 4 +- python/tvm/relax/op/manipulate.py | 2 +- python/tvm/relax/op/op_attrs.py | 3 +- .../transform/legalize_ops/manipulate.py | 8 +++- python/tvm/topi/index_put.py | 4 +- .../test_frontend_from_exported_program.py | 10 ++++- tests/python/relax/test_frontend_from_fx.py | 44 ++++++++++++++++--- 7 files changed, 60 insertions(+), 15 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 8e6f36f509a3..20b66fdd4bfa 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1117,7 +1117,7 @@ def _gather(self, node: fx.Node) -> relax.Var: dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) index = self.env[node.args[2]] return 
self.block_builder.emit(relax.op.gather_elements(x, index, axis=dim)) - + def _index_put_(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) tensor = args[0] @@ -1130,7 +1130,7 @@ def _index_put_(self, node: fx.Node) -> relax.Var: accumulate = accumulate.lower() == "true" elif not isinstance(accumulate, bool): accumulate = bool(accumulate) - + if isinstance(indices, (list, tuple)): indices = relax.Tuple(indices) if indices else relax.Tuple([]) return self.block_builder.emit(relax.op.index_put(tensor, indices, values, accumulate)) diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 8a963013b299..5183f281a598 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -600,7 +600,6 @@ def index_put( indices: Union[Expr, Tuple[Expr]], values: Expr, accumulate: bool = False, - reduction: str = "update" ) -> Expr: """This operation updates values in `data` at positions specified by `indices` with corresponding values from `values`. The `indices` is a tuple @@ -617,6 +616,7 @@ def index_put( Values to place at the specified indices accumulate : bool Whether to accumulate (add) values rather than replace (default: False) + Returns ------- result : relax.Expr diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 0c955e9473e2..fe527e38e8a8 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -143,11 +143,12 @@ class SqueezeAttrs(Attrs): class StackAttrs(Attrs): """Attributes for concat operator""" + @tvm._ffi.register_object("relax.attrs.IndexPutAttrs") class IndexPutAttrs(Attrs): """Attributes for index_put operator""" - + @tvm._ffi.register_object("relax.attrs.LayoutTransformAttrs") class LayoutTransformAttrs(Attrs): """Attributes used in layout_transform operator""" diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 3b9fff578ce3..86254d95c390 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -206,7 +206,13 @@ def _index_put(bb: BlockBuilder, call: Call) -> Expr: else: indices_list = [indices] - return bb.call_te(topi.index_put, data, indices_list, values, accumulate=accumulate,) + return bb.call_te( + topi.index_put, + data, + indices_list, + values, + accumulate=accumulate, + ) @register_legalize("relax.scatter_elements") diff --git a/python/tvm/topi/index_put.py b/python/tvm/topi/index_put.py index 8b486d41a0e7..3a439bad5c4d 100644 --- a/python/tvm/topi/index_put.py +++ b/python/tvm/topi/index_put.py @@ -83,13 +83,13 @@ def gen_ir(data_ptr, index_ptrs, values_ptr, out_ptr, reduce_func): # Calculate multi-dimensional index flat_index = 0 stride = 1 - for dim in range(len(shape)-1, -1, -1): + for dim in range(len(shape) - 1, -1, -1): # Get index and shift to positive if needed idx_val = indices[dim][k] shifted_idx = idx_val + (idx_val < 0) * shape[dim] flat_index += shifted_idx * stride stride *= shape[dim] - + reduce_func(out, flat_index, values[k]) return ib.get() diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 7eb522bb6e62..4566654449cc 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4288,7 +4288,10 @@ def main( ) -> R.Tuple(R.Tensor((8, 16, 32, 64), dtype="float32")): with R.dataflow(): lv: R.Tensor((8, 16, 32, 64), 
dtype="float32") = R.index_put( - data, R.tuple(indices_0, indices_1, indices_2, indices_3), values, accumulate=False + data, + R.tuple(indices_0, indices_1, indices_2, indices_3), + values, + accumulate=False, ) gv: R.Tuple(R.Tensor((8, 16, 32, 64), dtype="float32")) = (lv,) R.output(gv) @@ -4324,7 +4327,10 @@ def main( ) -> R.Tuple(R.Tensor((4, 8, 16, 32, 64), dtype="float32")): with R.dataflow(): lv: R.Tensor((4, 8, 16, 32, 64), dtype="float32") = R.index_put( - data, R.tuple(indices_0, indices_1, indices_2, indices_3, indices_4), values, accumulate=False + data, + R.tuple(indices_0, indices_1, indices_2, indices_3, indices_4), + values, + accumulate=False, ) gv: R.Tuple(R.Tensor((4, 8, 16, 32, 64), dtype="float32")) = (lv,) R.output(gv) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 15b4160e4d18..4a21b8272adb 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -4435,7 +4435,12 @@ def forward(self, data, indices_0, indices_1, values): indices_tuple = (indices_0, indices_1) return data.index_put_(indices_tuple, values, accumulate=False) - input_info_2d = [((32, 64), "float32"), ((128,), "int64"), ((128,), "int64"), ((128,), "float32")] + input_info_2d = [ + ((32, 64), "float32"), + ((128,), "int64"), + ((128,), "int64"), + ((128,), "float32"), + ] @I.ir_module class Expected2D: @@ -4460,7 +4465,13 @@ def forward(self, data, indices_0, indices_1, indices_2, values): indices_tuple = (indices_0, indices_1, indices_2) return data.index_put_(indices_tuple, values, accumulate=False) - input_info_3d = [((16, 32, 64), "float32"), ((128,), "int64"), ((128,), "int64"), ((128,), "int64"), ((128,), "float32")] + input_info_3d = [ + ((16, 32, 64), "float32"), + ((128,), "int64"), + ((128,), "int64"), + ((128,), "int64"), + ((128,), "float32"), + ] @I.ir_module class Expected3D: @@ -4486,7 +4497,14 @@ def forward(self, data, indices_0, indices_1, indices_2, indices_3, values): indices_tuple = (indices_0, indices_1, indices_2, indices_3) return data.index_put_(indices_tuple, values, accumulate=False) - input_info_4d = [((8, 16, 32, 64), "float32"), ((128,), "int64"), ((128,), "int64"), ((128,), "int64"), ((128,), "int64"), ((128,), "float32")] + input_info_4d = [ + ((8, 16, 32, 64), "float32"), + ((128,), "int64"), + ((128,), "int64"), + ((128,), "int64"), + ((128,), "int64"), + ((128,), "float32"), + ] @I.ir_module class Expected4D: @@ -4501,7 +4519,10 @@ def main( ) -> R.Tensor((8, 16, 32, 64), dtype="float32"): with R.dataflow(): lv: R.Tensor((8, 16, 32, 64), dtype="float32") = R.index_put( - data, R.tuple(indices_0, indices_1, indices_2, indices_3), values, accumulate=False + data, + R.tuple(indices_0, indices_1, indices_2, indices_3), + values, + accumulate=False, ) gv: R.Tensor((8, 16, 32, 64), dtype="float32") = lv R.output(gv) @@ -4513,7 +4534,15 @@ def forward(self, data, indices_0, indices_1, indices_2, indices_3, indices_4, v indices_tuple = (indices_0, indices_1, indices_2, indices_3, indices_4) return data.index_put_(indices_tuple, values, accumulate=False) - input_info_5d = [((4, 8, 16, 32, 64), "float32"), ((128,), "int64"), ((128,), "int64"), ((128,), "int64"), ((128,), "int64"), ((128,), "int64"), ((128,), "float32")] + input_info_5d = [ + ((4, 8, 16, 32, 64), "float32"), + ((128,), "int64"), + ((128,), "int64"), + ((128,), "int64"), + ((128,), "int64"), + ((128,), "int64"), + ((128,), "float32") + ] @I.ir_module class Expected5D: @@ -4529,7 +4558,10 @@ def main( ) -> 
R.Tensor((4, 8, 16, 32, 64), dtype="float32"): with R.dataflow(): lv: R.Tensor((4, 8, 16, 32, 64), dtype="float32") = R.index_put( - data, R.tuple(indices_0, indices_1, indices_2, indices_3, indices_4), values, accumulate=False + data, + R.tuple(indices_0, indices_1, indices_2, indices_3, indices_4), + values, + accumulate=False, ) gv: R.Tensor((4, 8, 16, 32, 64), dtype="float32") = lv R.output(gv) From f663b1d01562914b1cc2cd4c392b2009bcc4cf90 Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Mon, 21 Apr 2025 08:48:46 +0000 Subject: [PATCH 04/16] lint check --- python/tvm/relax/transform/legalize_ops/manipulate.py | 2 +- tests/python/relax/test_frontend_from_fx.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 86254d95c390..a66b60c0134b 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -209,7 +209,7 @@ def _index_put(bb: BlockBuilder, call: Call) -> Expr: return bb.call_te( topi.index_put, data, - indices_list, + indices_list, values, accumulate=accumulate, ) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 4a21b8272adb..a77c5d82b58e 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -4541,7 +4541,7 @@ def forward(self, data, indices_0, indices_1, indices_2, indices_3, indices_4, v ((128,), "int64"), ((128,), "int64"), ((128,), "int64"), - ((128,), "float32") + ((128,), "float32"), ] @I.ir_module From 3a58f0fc68f3055d8055a144bcfc586d993e1e4f Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Mon, 21 Apr 2025 09:11:21 +0000 Subject: [PATCH 05/16] lint issue --- python/tvm/topi/index_put.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/index_put.py b/python/tvm/topi/index_put.py index 3a439bad5c4d..aa447e50f6e9 100644 --- a/python/tvm/topi/index_put.py +++ b/python/tvm/topi/index_put.py @@ -75,7 +75,6 @@ def gen_ir(data_ptr, index_ptrs, values_ptr, out_ptr, reduce_func): values = ib.buffer_ptr(values_ptr) out = ib.buffer_ptr(out_ptr) - # Copy initial input data to output with ib.for_range(0, full_range, "i", kind="parallel") as i: out[i] = data[i] @@ -116,4 +115,5 @@ def add_func(dst_ptr, dst_index, update): out_buffers=[out_buf], name="index_put.generic", tag="index_put.generic", - ) \ No newline at end of file + ) + From 58af85528319eb3e113ecc28d1820caf2a5401f6 Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Mon, 21 Apr 2025 09:25:18 +0000 Subject: [PATCH 06/16] whitespace issue --- python/tvm/topi/index_put.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/topi/index_put.py b/python/tvm/topi/index_put.py index aa447e50f6e9..861e174f52a9 100644 --- a/python/tvm/topi/index_put.py +++ b/python/tvm/topi/index_put.py @@ -116,4 +116,3 @@ def add_func(dst_ptr, dst_index, update): name="index_put.generic", tag="index_put.generic", ) - From e4f2c94aa58912b04af083e2f13245c7682fc2c9 Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Mon, 21 Apr 2025 09:25:18 +0000 Subject: [PATCH 07/16] whitespace issue --- python/tvm/topi/index_put.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/python/tvm/topi/index_put.py b/python/tvm/topi/index_put.py index 861e174f52a9..41cbec502832 100644 --- a/python/tvm/topi/index_put.py +++ b/python/tvm/topi/index_put.py @@ -1,6 +1,6 @@ # Licensed to the 
Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information +# or more contrir_builderutor license agreements. See the NOTICE file +# distrir_builderuted with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance @@ -9,7 +9,7 @@ # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an +# software distrir_builderuted under the License is distrir_builderuted on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations @@ -18,7 +18,6 @@ from tvm import te from tvm import tir from . import utils -from .math import cast def index_put(data, indices, values, accumulate=False): @@ -68,17 +67,17 @@ def index_put(data, indices, values, accumulate=False): raise ValueError("All index tensors must have same length") def gen_ir(data_ptr, index_ptrs, values_ptr, out_ptr, reduce_func): - ib = tir.ir_builder.create() + ir_builder = tir.ir_builder.create() - data = ib.buffer_ptr(data_ptr) - indices = [ib.buffer_ptr(idx) for idx in index_ptrs] - values = ib.buffer_ptr(values_ptr) - out = ib.buffer_ptr(out_ptr) + data = ir_builder.buffer_ptr(data_ptr) + indices = [ir_builder.buffer_ptr(idx) for idx in index_ptrs] + values = ir_builder.buffer_ptr(values_ptr) + out = ir_builder.buffer_ptr(out_ptr) - with ib.for_range(0, full_range, "i", kind="parallel") as i: + with ir_builder.for_range(0, full_range, "i", kind="parallel") as i: out[i] = data[i] - with ib.for_range(0, index_len, "k", kind="parallel") as k: + with ir_builder.for_range(0, index_len, "k", kind="parallel") as k: # Calculate multi-dimensional index flat_index = 0 stride = 1 @@ -91,7 +90,7 @@ def gen_ir(data_ptr, index_ptrs, values_ptr, out_ptr, reduce_func): reduce_func(out, flat_index, values[k]) - return ib.get() + return ir_builder.get() def update_func(dst_ptr, dst_index, update): dst_ptr[dst_index] = update From cb6facfb9e8b2e2be7409dc3ae8b4e478a2fc2d2 Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Mon, 21 Apr 2025 10:25:42 +0000 Subject: [PATCH 08/16] whitespace issue --- src/relax/op/tensor/manipulate.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 09c67b2fc833..33616279409d 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -2046,7 +2046,7 @@ StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) { } // Check that the number of index tensors matches data dimensions - if (!data_sinfo->IsUnknownNdim() && indices_tensors.size() != static_cast(data_sinfo->ndim)) { + if (!data_sinfo->IsUnknownNdim()&&indices_tensors.size() != static_cast(data_sinfo->ndim)) { ctx->ReportFatal(Diagnostic::Error(call) << "IndexPut requires the number of index tensors (" << indices_tensors.size() << ") to match the data tensor dimensions (" << data_sinfo->ndim << ")"); From 54f058f3c9129b170c80f96e8e99e88607ed77b9 Mon Sep 17 00:00:00 2001 From: Pratheesh Date: Mon, 21 Apr 2025 10:59:11 +0000 Subject: [PATCH 09/16] line length error --- src/relax/op/tensor/manipulate.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff 
From 54f058f3c9129b170c80f96e8e99e88607ed77b9 Mon Sep 17 00:00:00 2001
From: Pratheesh
Date: Mon, 21 Apr 2025 10:59:11 +0000
Subject: [PATCH 09/16] line length error

---
 src/relax/op/tensor/manipulate.cc | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 33616279409d..bd94d2c3665e 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -2046,7 +2046,8 @@ StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) {
   }
 
   // Check that the number of index tensors matches data dimensions
-  if (!data_sinfo->IsUnknownNdim()&&indices_tensors.size() != static_cast<size_t>(data_sinfo->ndim)) {
+  if (!data_sinfo->IsUnknownNdim() &&
+      indices_tensors.size() != static_cast<size_t>(data_sinfo->ndim)) {
     ctx->ReportFatal(Diagnostic::Error(call)
                      << "IndexPut requires the number of index tensors (" << indices_tensors.size()
                      << ") to match the data tensor dimensions (" << data_sinfo->ndim << ")");

From 6a93b1ee48a72dde0d47cb597079ddc4ea6b09f8 Mon Sep 17 00:00:00 2001
From: Pratheesh
Date: Mon, 21 Apr 2025 11:22:04 +0000
Subject: [PATCH 10/16] trailing space issue

---
 include/tvm/relax/attrs/manipulate.h | 10 +++++-----
 src/relax/op/tensor/manipulate.cc    |  2 +-
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index 8eb0b087bdd0..943d2f4d0d71 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -188,11 +188,11 @@ struct IndexPutAttrs : public tvm::AttrsNode<IndexPutAttrs> {
 
   TVM_DECLARE_ATTRS(IndexPutAttrs, "relax.attrs.IndexPutAttrs") {
     TVM_ATTR_FIELD(accumulate)
-      .set_default(false)
-      .describe(
-          "Whether to accumulate (add) values rather than replace. "
-          "If true, performs tensor[indices] += values, "
-          "otherwise performs tensor[indices] = values.");
+        .set_default(false)
+        .describe(
+            "Whether to accumulate (add) values rather than replace. "
+            "If true, performs tensor[indices] += values, "
+            "otherwise performs tensor[indices] = values.");
   }
 };  // struct IndexPutAttrs
 
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index bd94d2c3665e..482ebe5cacb6 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -2047,7 +2047,7 @@ StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) {
 
   // Check that the number of index tensors matches data dimensions
   if (!data_sinfo->IsUnknownNdim() &&
-      indices_tensors.size() != static_cast<size_t>(data_sinfo->ndim)) {
+      indices_tensors.size() != static_cast<size_t>(data_sinfo->ndim)) {
     ctx->ReportFatal(Diagnostic::Error(call)
                      << "IndexPut requires the number of index tensors (" << indices_tensors.size()
                      << ") to match the data tensor dimensions (" << data_sinfo->ndim << ")");
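The accumulate attribute documented in IndexPutAttrs follows PyTorch's index_put_
semantics. A small PyTorch illustration (values chosen arbitrarily):

    import torch

    x = torch.zeros(2, 3)
    idx = (torch.tensor([0, 1]), torch.tensor([2, 0]))
    vals = torch.tensor([5.0, 7.0])

    x.index_put_(idx, vals, accumulate=False)  # tensor[indices] = values
    x.index_put_(idx, vals, accumulate=True)   # tensor[indices] += values
    # x[0, 2] is now 10.0 and x[1, 0] is 14.0: the second call added onto
    # the values written by the first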
From 3ed69970b3a472286b331c07e313d8b685f32768 Mon Sep 17 00:00:00 2001
From: Pratheesh
Date: Thu, 24 Apr 2025 06:30:58 +0000
Subject: [PATCH 11/16] modified conditions for parameters

---
 .../frontend/torch/base_fx_graph_translator.py  | 16 ++++++++--------
 python/tvm/relax/frontend/torch/fx_translator.py |  2 +-
 python/tvm/relax/op/manipulate.py                |  2 +-
 python/tvm/topi/index_put.py                     |  9 ++++-----
 4 files changed, 14 insertions(+), 15 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 20b66fdd4bfa..55e2546dbf02 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1118,21 +1118,21 @@ def _gather(self, node: fx.Node) -> relax.Var:
         index = self.env[node.args[2]]
         return self.block_builder.emit(relax.op.gather_elements(x, index, axis=dim))
 
-    def _index_put_(self, node: fx.Node) -> relax.Var:
+    def _index_put(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         tensor = args[0]
-        indices = args[1] if len(args) > 1 else node.kwargs.get("indices", ())
+        indices = args[1] if len(args) > 1 else node.kwargs.get("indices")
         values = args[2] if len(args) > 2 else node.kwargs.get("values")
         accumulate = args[3] if len(args) > 3 else node.kwargs.get("accumulate", False)
 
-        # Ensure accumulate is a boolean
-        if isinstance(accumulate, str):
-            accumulate = accumulate.lower() == "true"
-        elif not isinstance(accumulate, bool):
-            accumulate = bool(accumulate)
+        if indices is None or values is None:
+            raise ValueError("'indices' and 'values' arguments are required for the index_put operation")
+
+        if not isinstance(accumulate, bool):
+            raise TypeError("'accumulate' must be a boolean value, got {}".format(type(accumulate)))
 
         if isinstance(indices, (list, tuple)):
-            indices = relax.Tuple(indices) if indices else relax.Tuple([])
+            indices = relax.Tuple(indices)
         return self.block_builder.emit(relax.op.index_put(tensor, indices, values, accumulate))
 
     def _index_tensor(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index dbe42c8f4f17..6cb9db9b9f3d 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -801,7 +801,7 @@ def create_convert_map(
             "flatten": self._flatten,
             "flip": self._flip,
             "gather": self._gather,
-            "index_put_": self._index_put_,
+            "index_put_": self._index_put,
             "narrow": self._narrow,
             "numel": self._numel,
             "permute": self._permute,
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index 5183f281a598..259ad64ef53d 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -642,7 +642,7 @@ def index_put(
     ]
     """
     if isinstance(indices, (list, tuple)):
-        indices = RxTuple(indices) if indices else RxTuple([])
+        indices = RxTuple(indices)
     return _ffi_api.index_put(data, indices, values, accumulate)  # type: ignore
 
 
diff --git a/python/tvm/topi/index_put.py b/python/tvm/topi/index_put.py
index 41cbec502832..eb2e370897a6 100644
--- a/python/tvm/topi/index_put.py
+++ b/python/tvm/topi/index_put.py
@@ -18,6 +18,7 @@
 from tvm import te
 from tvm import tir
 from . import utils
+import math
 
 
 def index_put(data, indices, values, accumulate=False):
@@ -56,14 +57,12 @@
 
     # Prepare ranges and strides
     shape = data.shape
-    full_range = 1
-    for dim in shape:
-        full_range *= dim
+    full_range = math.prod(data.shape)
 
     # Check all indices have same length
-    index_len = indices[0].shape[0]
+    index_len = len(indices[0])
     for idx in indices[1:]:
-        if not utils.equal_const_int(idx.shape[0], index_len):
+        if not utils.equal_const_int(len(idx), index_len):
             raise ValueError("All index tensors must have same length")
 
     def gen_ir(data_ptr, index_ptrs, values_ptr, out_ptr, reduce_func):
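At the Relax level, the converter above emits a relax.op.index_put call with the
index tensors wrapped in a tuple expression. A minimal construction sketch,
assuming static shapes and using illustrative names not taken from the patch:

    from tvm import relax
    from tvm.relax.op import index_put

    bb = relax.BlockBuilder()
    data = relax.Var("data", relax.TensorStructInfo((4, 4), "float32"))
    i0 = relax.Var("i0", relax.TensorStructInfo((3,), "int64"))
    i1 = relax.Var("i1", relax.TensorStructInfo((3,), "int64"))
    vals = relax.Var("vals", relax.TensorStructInfo((3,), "float32"))
    with bb.function("main", [data, i0, i1, vals]):
        # one index tensor per dimension of `data`, wrapped in a relax.Tuple
        out = bb.emit(index_put(data, relax.Tuple([i0, i1]), vals, accumulate=True))
        bb.emit_func_output(out)
    mod = bb.get()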
From 5e73c9abe28ea64b3f22f9fe4ac65d67851b08c2 Mon Sep 17 00:00:00 2001
From: Pratheesh
Date: Thu, 24 Apr 2025 06:39:49 +0000
Subject: [PATCH 12/16] modified base_fx

---
 python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 55e2546dbf02..5dd78be4837d 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1127,7 +1127,7 @@ def _index_put(self, node: fx.Node) -> relax.Var:
 
         if indices is None or values is None:
             raise ValueError("'indices' and 'values' arguments are required for the index_put operation")
-        
+
         if not isinstance(accumulate, bool):
             raise TypeError("'accumulate' must be a boolean value, got {}".format(type(accumulate)))
 

From 72343736cda5c84e78543d407d14ad2de1ec71e9 Mon Sep 17 00:00:00 2001
From: Pratheesh
Date: Thu, 24 Apr 2025 07:01:03 +0000
Subject: [PATCH 13/16] resolved conflicts

---
 python/tvm/relax/frontend/torch/exported_program_translator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 962ade7455f2..9f097ef8df11 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -433,7 +433,7 @@ def create_convert_map(
             "flip.default": self._flip,
             "gather.default": self._gather,
             "index.Tensor": self._index_tensor,
-            "index_put_.default": self._index_put_,
+            "index_put_.default": self._index_put,
             "narrow.default": self._narrow,
             "permute.default": self._permute,
             "repeat.default": self._repeat,

From c5eda67cb33a4a1087446de9edb3f68f4f28a697 Mon Sep 17 00:00:00 2001
From: Pratheesh
Date: Thu, 24 Apr 2025 07:02:14 +0000
Subject: [PATCH 14/16] removed trailing whitespace

---
 python/tvm/relax/op/manipulate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index 259ad64ef53d..13334d1479d9 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -562,7 +562,7 @@ def index_tensor(data: Expr, indices: Union[Expr, List[Expr]]) -> Expr:
         or a Python ``list`` / ``tuple`` that will be promoted to a tuple
         expression automatically. Each tensor must have an integer dtype.
-        
+
     Returns
     -------
     result : relax.Expr
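With the "index_put_.default" entry registered above, an exported torch program
that calls index_put_ is expected to route through _index_put. A hedged
end-to-end sketch (the module, shapes, and indices are invented for illustration;
depending on the torch version, export may functionalize the in-place op before
it reaches the converter):

    import torch
    from tvm.relax.frontend.torch import from_exported_program

    class IndexPutModule(torch.nn.Module):
        def forward(self, x, values):
            out = x.clone()  # clone so the exported graph avoids input mutation
            idx = (torch.tensor([0, 2]), torch.tensor([1, 3]))
            out.index_put_(idx, values, accumulate=True)
            return out

    ep = torch.export.export(IndexPutModule(), (torch.zeros(4, 4), torch.ones(2)))
    mod = from_exported_program(ep)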
From 1e4bd42a01461c7741da36693697b15e7a31c81d Mon Sep 17 00:00:00 2001
From: Pratheesh
Date: Thu, 24 Apr 2025 07:17:07 +0000
Subject: [PATCH 15/16] lint issue

---
 python/tvm/topi/index_put.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/topi/index_put.py b/python/tvm/topi/index_put.py
index eb2e370897a6..e9f96a1c86e4 100644
--- a/python/tvm/topi/index_put.py
+++ b/python/tvm/topi/index_put.py
@@ -15,10 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 """IndexPut operator"""
+import math
 from tvm import te
 from tvm import tir
 from . import utils
-import math
 
 
 def index_put(data, indices, values, accumulate=False):

From 344c6b12ae5fdcbb3aac2e583e5dd5b090b955c7 Mon Sep 17 00:00:00 2001
From: Pratheesh
Date: Thu, 24 Apr 2025 08:54:56 +0000
Subject: [PATCH 16/16] unity issue

---
 python/tvm/topi/index_put.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/tvm/topi/index_put.py b/python/tvm/topi/index_put.py
index e9f96a1c86e4..f51c6718ab99 100644
--- a/python/tvm/topi/index_put.py
+++ b/python/tvm/topi/index_put.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """IndexPut operator"""
-import math
 from tvm import te
 from tvm import tir
 from . import utils
@@ -57,7 +56,9 @@
 
     # Prepare ranges and strides
     shape = data.shape
-    full_range = math.prod(data.shape)
+    full_range = 1
+    for dim in shape:
+        full_range *= dim
 
     # Check all indices have same length
     index_len = len(indices[0])