From 2b5b10a45c7e8f94b5de4e06fd5d415c5b924f0d Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Thu, 5 Jun 2025 13:12:48 +0530 Subject: [PATCH 1/6] add support for bucketize --- include/tvm/relax/attrs/search.h | 12 +++++ .../torch/base_fx_graph_translator.py | 10 ++++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 1 + python/tvm/relax/op/__init__.py | 2 +- python/tvm/relax/op/search.py | 25 ++++++++++ .../relax/transform/legalize_ops/search.py | 9 ++++ python/tvm/script/ir_builder/relax/ir.py | 2 + src/relax/op/tensor/search.cc | 48 +++++++++++++++++++ src/relax/op/tensor/search.h | 10 ++++ .../test_frontend_from_exported_program.py | 25 ++++++++++ tests/python/relax/test_frontend_from_fx.py | 22 +++++++++ 12 files changed, 166 insertions(+), 1 deletion(-) diff --git a/include/tvm/relax/attrs/search.h b/include/tvm/relax/attrs/search.h index b6adb3e437c6..89d575384deb 100644 --- a/include/tvm/relax/attrs/search.h +++ b/include/tvm/relax/attrs/search.h @@ -42,6 +42,18 @@ struct ArgmaxArgminAttrs : public tvm::AttrsNode { } }; // struct ArgmaxArgminAttrs +/*! \brief Attributes for bucketize operator */ +struct BucketizeAttrs : public tvm::AttrsNode { + bool out_int32; + bool right; + + TVM_DECLARE_ATTRS(BucketizeAttrs, "relax.attrs.BucketizeAttrs") { + TVM_ATTR_FIELD(out_int32).describe( + "Indicate the output data type. torch.int32 if True, torch.int64 otherwise "); + TVM_ATTR_FIELD(right).describe("Determines the behavior for values in boundaries "); + } +}; // struct BucketizeAttrs + } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 485b7c088a15..e5a815270fc0 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1370,6 +1370,16 @@ def _where(self, node: fx.Node) -> relax.Var: y = self.env[node.args[2]] return self.block_builder.emit(relax.op.where(condition, x, y)) + def _bucketize(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + input = args[0] + boundaries = args[1] + + right = node.kwargs.get("right", False) + out_int32 = node.kwargs.get("out_int32", False) + + return self.block_builder.emit(relax.op.bucketize(input, boundaries, out_int32, right)) + ########## Manipulation ########## def _argsort(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 57a6577eaf4a..85836ec148ae 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -482,6 +482,7 @@ def create_convert_map( "argmax.default": self._argmax_argmin(relax.op.argmax), "argmin.default": self._argmax_argmin(relax.op.argmin), "where.self": self._where, + "bucketize.Tensor": self._bucketize, # tensor manipulation "argsort.default": self._argsort, "broadcast_to.default": self._broadcast_to, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 33abccbe5f85..be0fb0af9754 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -884,6 +884,7 @@ def create_convert_map( "argmax": self._argmax_argmin(relax.op.argmax), "argmin": self._argmax_argmin(relax.op.argmin), "where": self._where, + "bucketize": self._bucketize, # tensor manipulation "argsort": self._argsort, "broadcast_to": self._broadcast_to, diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 9388831fce31..fd3672368b68 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -115,7 +115,7 @@ from .mask import masked_fill from .qdq import dequantize, quantize from .sampling import multinomial_from_uniform -from .search import argmax, argmin, where +from .search import argmax, argmin, where, bucketize from .set import nonzero, unique from .sorting import argsort, sort, topk from .statistical import cumprod, cumsum, max, mean, min, prod, std, sum, variance diff --git a/python/tvm/relax/op/search.py b/python/tvm/relax/op/search.py index b097d78234d5..016b22b9b936 100644 --- a/python/tvm/relax/op/search.py +++ b/python/tvm/relax/op/search.py @@ -102,3 +102,28 @@ def argmin(x: Expr, axis: Optional[int] = None, keepdims: bool = False) -> Expr: The computed result. """ return _ffi_api.argmin(x, axis, keepdims) # type: ignore + + +def bucketize(input_tensor, boundaries, out_int32=False, right=False): + """Returns the indices of the buckets to which each value in the input belongs. + + Parameters + ---------- + input_tensor : relax.Expr + N-D tensor containing the search values. + + boundaries : relax.Expr + 1-D tensor, must contain a strictly increasing sequence, or the return value is undefined. + + out_int32 : Optional[bool] + Indicate the output data type. int32 if True, int64 otherwise. Default=False + + right : Optional[bool] + Determines the behavior for values in boundaries. Similar to torch.bucketize + + Returns + ------- + result : relax.Expr + The computed result with same shape as input_tensor. + """ + return _ffi_api.bucketize(input_tensor, boundaries, out_int32, right) diff --git a/python/tvm/relax/transform/legalize_ops/search.py b/python/tvm/relax/transform/legalize_ops/search.py index 19ff00774ca0..a49bcabdab01 100644 --- a/python/tvm/relax/transform/legalize_ops/search.py +++ b/python/tvm/relax/transform/legalize_ops/search.py @@ -39,3 +39,12 @@ def argmax_argmin_call_te(bb: BlockBuilder, call: Call) -> Expr: register_legalize("relax.argmax", _argmax_argmin(topi.argmax)) register_legalize("relax.argmin", _argmax_argmin(topi.argmin)) + + +@register_legalize("relax.bucketize") +def _bucketize(bb, call): + input = call.args[0] + boundaries = call.args[1] + right = call.attrs.right + print(input.struct_info.dtype) + return bb.call_te(topi.searchsorted, boundaries, input, right, input.struct_info.dtype) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 1e48e9ea1ad7..43590dfa25e3 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -58,6 +58,7 @@ bitwise_or, bitwise_xor, broadcast_to, + bucketize, builtin, call_builtin_with_ctx, call_dps_packed, @@ -731,6 +732,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "bitwise_or", "bitwise_xor", "broadcast_to", + "bucketize", "builtin", "call_inplace_packed", "call_packed", diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc index 83e0e246b1bf..e65b16c712cc 100644 --- a/src/relax/op/tensor/search.cc +++ b/src/relax/op/tensor/search.cc @@ -30,6 +30,54 @@ namespace tvm { namespace relax { +/* relax.bucketize */ +TVM_REGISTER_NODE_TYPE(BucketizeAttrs); + +Expr bucketize(Expr input_tensor, Expr boundaries, bool out_int32, bool right) { + auto attrs = make_object(); + attrs->out_int32 = out_int32; + attrs->right = right; + static const Op& op = Op::Get("relax.bucketize"); + return Call(op, {input_tensor, boundaries}, Attrs(attrs), {}); +} + +TVM_FFI_REGISTER_GLOBAL("relax.op.bucketize").set_body_typed(bucketize); + +StructInfo InferStructInfoBucketize(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo input_tensor_info = input_sinfo[0]; + TensorStructInfo boundaries_info = input_sinfo[1]; + + if (!boundaries_info->IsUnknownNdim() && boundaries_info->ndim != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Bucketize requires boundary to be 1-D array but got " + << boundaries_info->ndim); + } + + auto attrs = call->attrs.as(); + DataType out_dtype; + out_dtype = DataType::Int(64); + if (attrs->out_int32) { + out_dtype = DataType::Int(32); + } + + const auto* data_shape = input_tensor_info->shape.as(); + if (data_shape) { + return TensorStructInfo(ShapeExpr(data_shape->values), out_dtype, input_tensor_info->vdevice); + } + return TensorStructInfo(out_dtype, input_tensor_info->ndim, input_tensor_info->vdevice); +} + +TVM_REGISTER_OP("relax.bucketize") + .set_num_inputs(2) + .add_argument("input_tensor", "Tensor", + " N-D tensor or a Scalar containing the search value(s).") + .add_argument("boundaries", "Tensor", + "1-D tensor, must contain a strictly increasing sequence, or the return value is " + "undefined.") + .set_attr("FInferStructInfo", InferStructInfoBucketize) + .set_attr("FPurity", Bool(true)); + /* relax.where */ Expr where(Expr condition, Expr x1, Expr x2) { static const Op& op = Op::Get("relax.where"); diff --git a/src/relax/op/tensor/search.h b/src/relax/op/tensor/search.h index eb40171790a3..333b5afe76c7 100644 --- a/src/relax/op/tensor/search.h +++ b/src/relax/op/tensor/search.h @@ -30,6 +30,16 @@ namespace tvm { namespace relax { +/*! + * \brief Returns the indices of the buckets to which each value in the input belongs. + * \param input_tensor N-D tensor containing the search values. + * \param boundaries 1-D tensor, must contain a strictly increasing sequence. + * \param out_int32 Indicate the output data type. int32 if True, int64 otherwise. + * \param right Determines the behavior for values in boundaries. Similar to torch.bucketize + + * \return The computed result with the same shape as input. + */ +Expr bucketize(Expr input_tensor, Expr boundaries, bool out_int32, bool right); /*! * \brief Selecting elements from either the input tensors depending on the value of the diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index dd04833e07b8..e2e57e690961 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -5427,6 +5427,31 @@ def main( verify_model(Where(), (condition, x, y), {}, Expected) +def test_bucketize(): + class Bucketize(Module): + def forward(self, input_tensor, boundaries): + return torch.bucketize(input_tensor, boundaries) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((20,), dtype="int64"), boundaries: R.Tensor((10,), dtype="int64") + ) -> R.Tuple(R.Tensor((20,), dtype="int64")): + with R.dataflow(): + lv: R.Tensor((20,), dtype="int64") = R.bucketize( + input, boundaries, out_int32=False, right=False + ) + gv: R.Tuple(R.Tensor((20,), dtype="int64")) = (lv,) + R.output(gv) + return gv + + input_tensor = torch.arange(0, 20) + boundaries = torch.arange(0, 20, 2) + + verify_model(Bucketize(), (input_tensor, boundaries), {}, Expected) + + def test_argsort(): class Argsort(Module): def forward(self, x): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index f33b55085825..a3f1e823398c 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -5827,6 +5827,28 @@ def main( ) +def test_bucketize(): + class Bucketize(Module): + def forward(self, input_tensor, boundaries): + return torch.bucketize(input_tensor, boundaries) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((5, 3), dtype="float32"), boundaries: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((5, 3), dtype="int64"): + with R.dataflow(): + lv: R.Tensor((5, 3), dtype="int64") = R.bucketize( + input, boundaries, out_int32=False, right=False + ) + gv: R.Tensor((5, 3), dtype="int64") = lv + R.output(gv) + return gv + + verify_model(Bucketize(), [([5, 3], "float32"), ([10], "float32")], {}, Expected) + + def test_argsort(): class Argsort(Module): def forward(self, x): From 46d2bb42f114ba3c10a4bdae77dfff624f571da9 Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Thu, 5 Jun 2025 13:42:19 +0530 Subject: [PATCH 2/6] fix lint issue --- python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 4 ++-- python/tvm/relax/transform/legalize_ops/search.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index e5a815270fc0..ea2578acdb2b 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1372,13 +1372,13 @@ def _where(self, node: fx.Node) -> relax.Var: def _bucketize(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) - input = args[0] + input_tensor = args[0] boundaries = args[1] right = node.kwargs.get("right", False) out_int32 = node.kwargs.get("out_int32", False) - return self.block_builder.emit(relax.op.bucketize(input, boundaries, out_int32, right)) + return self.block_builder.emit(relax.op.bucketize(input_tensor, boundaries, out_int32, right)) ########## Manipulation ########## diff --git a/python/tvm/relax/transform/legalize_ops/search.py b/python/tvm/relax/transform/legalize_ops/search.py index a49bcabdab01..063eeda516d1 100644 --- a/python/tvm/relax/transform/legalize_ops/search.py +++ b/python/tvm/relax/transform/legalize_ops/search.py @@ -43,8 +43,7 @@ def argmax_argmin_call_te(bb: BlockBuilder, call: Call) -> Expr: @register_legalize("relax.bucketize") def _bucketize(bb, call): - input = call.args[0] + input_tensor = call.args[0] boundaries = call.args[1] right = call.attrs.right - print(input.struct_info.dtype) - return bb.call_te(topi.searchsorted, boundaries, input, right, input.struct_info.dtype) + return bb.call_te(topi.searchsorted, boundaries, input_tensor, right, input_tensor.struct_info.dtype) From f77e164576519d04c709a08021d1f0d21a153a54 Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Thu, 5 Jun 2025 13:49:52 +0530 Subject: [PATCH 3/6] Fix lint issue --- python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 4 +++- python/tvm/relax/transform/legalize_ops/search.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index ea2578acdb2b..c055d89f0476 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1378,7 +1378,9 @@ def _bucketize(self, node: fx.Node) -> relax.Var: right = node.kwargs.get("right", False) out_int32 = node.kwargs.get("out_int32", False) - return self.block_builder.emit(relax.op.bucketize(input_tensor, boundaries, out_int32, right)) + return self.block_builder.emit( + relax.op.bucketize(input_tensor, boundaries, out_int32, right) + ) ########## Manipulation ########## diff --git a/python/tvm/relax/transform/legalize_ops/search.py b/python/tvm/relax/transform/legalize_ops/search.py index 063eeda516d1..89fddb4b95d8 100644 --- a/python/tvm/relax/transform/legalize_ops/search.py +++ b/python/tvm/relax/transform/legalize_ops/search.py @@ -46,4 +46,6 @@ def _bucketize(bb, call): input_tensor = call.args[0] boundaries = call.args[1] right = call.attrs.right - return bb.call_te(topi.searchsorted, boundaries, input_tensor, right, input_tensor.struct_info.dtype) + return bb.call_te( + topi.searchsorted, boundaries, input_tensor, right, input_tensor.struct_info.dtype + ) From 1545a3e68baf8d536d3e30dc3a9669f9354889c4 Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Fri, 13 Jun 2025 12:48:29 +0530 Subject: [PATCH 4/6] Add GPU code for bucketize --- .../tvm/relax/backend/dispatch_sort_scan.py | 12 +++ python/tvm/topi/gpu/sort.py | 89 ++++++++++++++++++- 2 files changed, 100 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/backend/dispatch_sort_scan.py b/python/tvm/relax/backend/dispatch_sort_scan.py index f8a7dfe2037d..1dac0bf230f3 100644 --- a/python/tvm/relax/backend/dispatch_sort_scan.py +++ b/python/tvm/relax/backend/dispatch_sort_scan.py @@ -75,6 +75,18 @@ def visit_call_(self, call: relax.Call) -> relax.Expr: if not isinstance(call.op, Op): return super().visit_call_(call) + if call.op.name == "relax.bucketize": + input_tensor = call.args[0] + boundaries = call.args[1] + right = call.attrs.right + tgt = self._get_target(call.struct_info) + te_func = topi.searchsorted + with tgt: + if self.is_gpu_target(tgt): + te_func = topi.gpu.searchsorted + return self.builder_.call_te( + te_func, boundaries, input_tensor, right, input_tensor.struct_info.dtype + ) if call.op.name == "relax.sort": tgt = self._get_target(call.struct_info) te_func = topi.sort diff --git a/python/tvm/topi/gpu/sort.py b/python/tvm/topi/gpu/sort.py index 71854e43997a..eb48da0a022a 100644 --- a/python/tvm/topi/gpu/sort.py +++ b/python/tvm/topi/gpu/sort.py @@ -20,8 +20,9 @@ from tvm import te from ..transform import strided_slice, transpose -from ..utils import ceil_div, swap +from ..utils import ceil_div, swap, prod from ..math import cast, ceil_log2 +from ..searchsorted import binary_search def _get_threads(ib, nthread_tx, nthread_bx, nthread_by): @@ -937,3 +938,89 @@ def f_compute(ins, outs): out = out[1] return out + + +def searchsorted(sorted_sequence, values, right=False, out_dtype="int64"): + """Find indices where elements should be inserted to maintain order. + If `sorted_sequence` is N-dimensional, the innermost dimension of + `values` are searched in the corresponding dimension of `sorted_sequence`. + + This implementation is optimized for GPU execution. + + Parameters + ---------- + sorted_sequence : te.Tensor + N-D or 1-D Tensor, containing monotonically increasing sequence + on the innermost dimension. + + values : te.Tensor + N-D Tensor containing the search values. When `sorted_sequence` is 1-D, + the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence` + and `values` must be the same, and outer N-1 axes must have the same size. + + right : bool, optional + Controls which index is returned if a value lands exactly on one of sorted values. If + False (side='left'), the index of the first suitable location found is given. If true + (side='right'), return the last such index. + + out_dtype : string, optional + The data type of the output indices. + + Returns + ------- + indices : te.Tensor + Tensor with same shape as values, representing the indices of + elements of `values` if they are inserted in `sorted_sequence`. + """ + if len(sorted_sequence.shape) > 1: + for i in range(len(values.shape) - 1): + assert ( + values.shape[i] == sorted_sequence.shape[i] + ), "Outer dimensions of sorted_sequence and values must match for N-D searchsorted" + + def ir(sorted_sequence_buf, values_buf, indices_buf): + ib = tvm.tir.ir_builder.create() + sorted_sequence_shape = sorted_sequence_buf.shape + values_shape = values_buf.shape + num_search = prod(values_shape) + search_range = sorted_sequence_shape[-1] + + sorted_sequence_ptr = ib.buffer_ptr(sorted_sequence_buf) + values_ptr = ib.buffer_ptr(values_buf) + indices_ptr = ib.buffer_ptr(indices_buf) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = ceil_div(num_search, nthread_tx) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * nthread_tx + tx + + with ib.if_scope(tid < num_search): + if len(sorted_sequence_shape) == 1: + sequence_offset = 0 + else: + sequence_id = tid // values_shape[-1] + sequence_offset = sequence_id * search_range + + indices_ptr[tid] = binary_search( + ib, + sequence_offset, + search_range, + sorted_sequence_ptr, + values_ptr[tid], + right, + out_dtype, + ) + + return ib.get() + + return te.extern( + values.shape, + [sorted_sequence, values], + lambda ins, outs: ir(ins[0], ins[1], outs[0]), + name="searchsorted_gpu", + dtype=out_dtype, + ) From 8165c1df37235f41c897c6eb59ea94e9264e7544 Mon Sep 17 00:00:00 2001 From: Kavin mcw Date: Mon, 30 Jun 2025 12:04:09 +0530 Subject: [PATCH 5/6] Resolve merge conflict --- include/tvm/relax/attrs/search.h | 35 ++++++++++++++++++++++---------- src/relax/op/tensor/search.cc | 11 +++++++--- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/include/tvm/relax/attrs/search.h b/include/tvm/relax/attrs/search.h index 89d575384deb..6fdbe59cea74 100644 --- a/include/tvm/relax/attrs/search.h +++ b/include/tvm/relax/attrs/search.h @@ -30,28 +30,41 @@ namespace tvm { namespace relax { /*! \brief Attributes for search operators */ -struct ArgmaxArgminAttrs : public tvm::AttrsNode { +struct ArgmaxArgminAttrs : public AttrsNodeReflAdapter { Optional axis; bool keepdims; - TVM_DECLARE_ATTRS(ArgmaxArgminAttrs, "relax.attrs.ArgmaxArgminAttrs") { - TVM_ATTR_FIELD(axis).describe("The axis along which to perform the argmin/argmax."); - TVM_ATTR_FIELD(keepdims).describe( - "If this is set to `True`, the reduced axis is left in the result as dimension with size " - "one."); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("axis", &ArgmaxArgminAttrs::axis, + "The axis along which to perform the argmin/argmax.") + .def_ro("keepdims", &ArgmaxArgminAttrs::keepdims, + "If this is set to `True`, the reduced axis is left in the result as dimension " + "with size " + "one."); } + + static constexpr const char* _type_key = "relax.attrs.ArgmaxArgminAttrs"; + TVM_FFI_DECLARE_FINAL_OBJECT_INFO(ArgmaxArgminAttrs, BaseAttrsNode); }; // struct ArgmaxArgminAttrs /*! \brief Attributes for bucketize operator */ -struct BucketizeAttrs : public tvm::AttrsNode { +struct BucketizeAttrs : public tvm::AttrsNodeReflAdapter { bool out_int32; bool right; - TVM_DECLARE_ATTRS(BucketizeAttrs, "relax.attrs.BucketizeAttrs") { - TVM_ATTR_FIELD(out_int32).describe( - "Indicate the output data type. torch.int32 if True, torch.int64 otherwise "); - TVM_ATTR_FIELD(right).describe("Determines the behavior for values in boundaries "); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("out_int32", &BucketizeAttrs::out_int32, + "Indicate the output datatype, int32 if True, int64 otherwise.") + .def_ro("right", &BucketizeAttrs::right, + "Determines the behavior for values in boundaries"); } + + static constexpr const char* _type_key = "relax.attrs.BucketizeAttrs"; + TVM_FFI_DECLARE_FINAL_OBJECT_INFO(BucketizeAttrs, BaseAttrsNode); }; // struct BucketizeAttrs } // namespace relax diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc index e65b16c712cc..2babb7f490cf 100644 --- a/src/relax/op/tensor/search.cc +++ b/src/relax/op/tensor/search.cc @@ -30,15 +30,20 @@ namespace tvm { namespace relax { +TVM_FFI_STATIC_INIT_BLOCK({ + ArgmaxArgminAttrs::RegisterReflection(); + BucketizeAttrs::RegisterReflection(); +}); + /* relax.bucketize */ TVM_REGISTER_NODE_TYPE(BucketizeAttrs); Expr bucketize(Expr input_tensor, Expr boundaries, bool out_int32, bool right) { auto attrs = make_object(); - attrs->out_int32 = out_int32; - attrs->right = right; + attrs->out_int32 = std::move(out_int32); + attrs->right = std::move(right); static const Op& op = Op::Get("relax.bucketize"); - return Call(op, {input_tensor, boundaries}, Attrs(attrs), {}); + return Call(op, {std::move(input_tensor), std::move(boundaries)}, Attrs(attrs), {}); } TVM_FFI_REGISTER_GLOBAL("relax.op.bucketize").set_body_typed(bucketize); From 55e8540d206f247f97fc4ccec85d2242bba9758d Mon Sep 17 00:00:00 2001 From: kavin-mcw Date: Mon, 30 Jun 2025 13:07:57 +0530 Subject: [PATCH 6/6] Fix lint issue --- src/relax/op/tensor/search.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc index 2babb7f490cf..3e0236fc28e5 100644 --- a/src/relax/op/tensor/search.cc +++ b/src/relax/op/tensor/search.cc @@ -31,7 +31,7 @@ namespace tvm { namespace relax { TVM_FFI_STATIC_INIT_BLOCK({ - ArgmaxArgminAttrs::RegisterReflection(); + ArgmaxArgminAttrs::RegisterReflection(); BucketizeAttrs::RegisterReflection(); });