From 0c8c9046419822a0028611564f01dca778d4c590 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 02:34:34 -0400 Subject: [PATCH 001/105] suddenly copy.default is unsupported --- tests/python/relax/test_from_exported_to_cuda.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 6cc12370d648..7f0204cd2e1f 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -202,6 +202,11 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_copy_fails_(target, dev): + assert False, "test_copy_fails_ indeed fails" + + @tvm.testing.parametrize_targets("cuda") def test_upsample_with_size(target, dev): """ From f7f063786d0f26ba1ed79b6d1ffaaf858055f54b Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 14:58:58 -0400 Subject: [PATCH 002/105] wip --- .../torch/base_fx_graph_translator.py | 20 +++++++++++++++++++ .../torch/exported_program_translator.py | 1 + 2 files changed, 21 insertions(+) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 6bbc9d5de618..40342a830674 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1021,6 +1021,7 @@ def _scatter(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.scatter_elements(x, index, src, axis=dim)) def _split(self, node: fx.Node) -> relax.Var: + """ torch.split with split_size passed as an argument""" x = self.env[node.args[0]] split_size = node.args[1] dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0) @@ -1031,6 +1032,25 @@ def _split(self, node: fx.Node) -> relax.Var: n_section.append(s + cum_sum) else: n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size + print("calling split with n_section", n_section) + print("calling split with dim", dim) + return self.block_builder.emit(relax.op.split(x, n_section, dim)) + + def _split_with_sizes(self, node: fx.Node) -> relax.Var: + """ torch.split with a list of section sizes passed as an argument""" + # TODO + x = self.env[node.args[0]] + split_size = node.args[1] + dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0) + if isinstance(split_size, (list, tuple)): + n_section = [] + for s in split_size[:-1]: + cum_sum = 0 if not n_section else n_section[-1] + n_section.append(s + cum_sum) + else: + n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size + print("calling split with n_section", n_section) + print("calling split with dim", dim) return self.block_builder.emit(relax.op.split(x, n_section, dim)) def _squeeze(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index bc7a4c4cb046..834002a02170 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -329,6 +329,7 @@ def create_convert_map( "select.int": self._select, "slice.Tensor": self._slice, "split.Tensor": self._split, + "split_with_sizes.default": self._split_with_sizes, "squeeze.default": self._squeeze, "squeeze.dim": self._squeeze, "take.default": 
self._take, From 7e4cf058ba43bd49904905362301dd26d50a6b68 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 16:48:01 -0400 Subject: [PATCH 003/105] Able to split uneven tensors! Remaining TODOs Make sure that also works if we input a list of indices Write tests at the py and cpp level Cleanup --- include/tvm/topi/transform.h | 32 +++++++++++-------- python/tvm/relax/op/manipulate.py | 1 + .../transform/legalize_ops/manipulate.py | 15 +++++---- .../framework/tensorrt/transform_tensorrt.cc | 2 ++ src/relax/op/tensor/manipulate.cc | 4 +++ src/topi/transform.cc | 7 ++-- 6 files changed, 39 insertions(+), 22 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index faacd2ce5760..62f69d9c825e 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -575,8 +575,19 @@ inline Tensor stack(const Array& inputs, int axis = 0, std::string name * * \return A Tensor whose op member is the split operation */ -inline Array split(const Tensor& x, Array split_indices, int axis, - std::string name = "T_split", std::string tag = kInjective) { +inline Array split_indices_array(const Tensor& x, Array split_indices, int axis, + std::string name = "T_split", + std::string tag = kInjective) { + printf("we are in the transform'h's split\n"); + // int x11 = 10; + // int y11 = 5; + // while (y11) { + // x11 -= 2; + // y11--; + // } + // int z11 = 50 / x11; + // printf("z11 is %d\n", z11); + if (axis < 0) { axis += static_cast(x->shape.size()); } @@ -968,9 +979,10 @@ inline Tensor strided_slice(const Tensor& x, const Array& begin, const * * \return A Tensor whose op member is the split operation */ -inline Array split_sections(const Tensor& x, int num_sections, int axis, - std::string name = "T_split_sections", - std::string tag = kInjective) { +inline Array split_n_sections(const Tensor& x, int num_sections, int axis, + std::string name = "T_split_sections", + std::string tag = kInjective) { + printf("We are in transform.h's splits_sections\n"); if (axis < 0) { axis += static_cast(x->shape.size()); } @@ -980,14 +992,8 @@ inline Array split_sections(const Tensor& x, int num_sections, int axis, ICHECK_GT(num_sections, 0) << "Slice count must be > 0"; - if (auto node = src_axis_size.as()) { - ICHECK_EQ(node->value % num_sections, 0) - << "num_sections must be an integer factor of the size of axis " << axis << " (" - << node->value << ")"; - } - Array split_indices; - auto seg_size = indexdiv(src_axis_size, num_sections); + auto seg_size = indexdiv(src_axis_size + num_sections - 1, num_sections); for (int i = 0; i < num_sections; ++i) { // region at index 0 is added by split() if (i != 0) { @@ -995,7 +1001,7 @@ inline Array split_sections(const Tensor& x, int num_sections, int axis, } } - return split(x, split_indices, axis, name, tag); + return split_indices_array(x, split_indices, axis, name, tag); } /*! 
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 0f6e537ab3d6..85f08b7f98bd 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -253,6 +253,7 @@ def split( """ if isinstance(indices_or_sections, int): indices_or_sections = IntImm("int64", indices_or_sections) + print("CALLING _ffi_api.split !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") return _ffi_api.split(x, indices_or_sections, axis) # type: ignore diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index c71a41dc1c2d..cefbba55fce3 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -107,18 +107,19 @@ def _permute_dims(bb: BlockBuilder, call: Call) -> Expr: @register_legalize("relax.split") def _split(bb: BlockBuilder, call: Call) -> Expr: + print("is this the split op??????????????????????????????????/") if isinstance(call.attrs.indices_or_sections, tir.IntImm): indices_or_sections = call.attrs.indices_or_sections.value modulo = tvm.arith.Analyzer().simplify( call.args[0].struct_info.shape.values[call.attrs.axis] % indices_or_sections ) - if isinstance(modulo, tir.IntImm): - if modulo != 0: - logging.info( - "Split cannot be legalized by TOPI when the axis being split has " - "length that not divisible by the input number of section." - ) - return call + # if isinstance(modulo, tir.IntImm): + # if modulo != 0: + # logging.info( + # "Split cannot be legalized by TOPI when the axis being split has " + # "length that not divisible by the input number of section." + # ) + # return call else: indices_or_sections = call.attrs.indices_or_sections return bb.call_te(topi.split, call.args[0], indices_or_sections, call.attrs.axis) diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc index 94dfec7ea621..92d090946cdb 100644 --- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc +++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc @@ -782,6 +782,8 @@ Expr RewriteShapeLike(BlockBuilder builder, const Var& var, const Call& src_call Expr RewriteSplit(BlockBuilder builder, const Var& var, const Call& src_call, const Map& new_calls, const String& config) { + printf("INSIDE REWRITESPLIT !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!11\n"); + std::cout << "Why is this not building?" << std::endl; const auto& call = new_calls.count(src_call) ? 
new_calls[src_call] : src_call; const auto& input_shape = ExprUtils::GetShape(call->args[0]); const auto* src_attrs = src_call->attrs.as(); diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index cb738db363ee..f40f0ac36e0a 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -824,7 +824,10 @@ TVM_REGISTER_OP("relax.reshape") /* relax.split */ TVM_REGISTER_NODE_TYPE(SplitAttrs); +#include + Expr split(Expr x, Variant> indices_or_sections, int axis) { + printf("INSIDE SPLIT!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"); ObjectPtr attrs = make_object(); if (const auto* indices = indices_or_sections.as()) { for (int i = 0; i < static_cast(indices->size()); ++i) { @@ -848,6 +851,7 @@ Expr split(Expr x, Variant> indices_or_sections, int axis) attrs->axis = axis; static const Op& op = Op::Get("relax.split"); + printf("CALLING OP !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"); return Call(op, {std::move(x)}, Attrs(attrs), {}); } diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 2e0fde3b289f..a0df64faa291 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -82,11 +82,14 @@ TVM_REGISTER_GLOBAL("topi.ndarray_size").set_body([](TVMArgs args, TVMRetValue* *rv = ndarray_size(args[0], args[1]); }); +#include TVM_REGISTER_GLOBAL("topi.split").set_body([](TVMArgs args, TVMRetValue* rv) { + printf("we are in transform.cc's topi.split, called by _split python func in manipulate.py\n"); + if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) { - *rv = split_sections(args[0], args[1], args[2]); + *rv = split_n_sections(args[0], args[1], args[2]); } else { - *rv = split(args[0], args[1], args[2]); + *rv = split_indices_array(args[0], args[1], args[2]); } }); From dcbee0c81befa6a371215832bc50e8db07a12d9e Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 17:31:54 -0400 Subject: [PATCH 004/105] split size test passes! --- .../relax/test_from_exported_to_cuda.py | 45 +++++++++++++++---- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 7f0204cd2e1f..be438cac0e05 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. 
+import sys +sys.path.append('/ssd1/htalendr/tvm/python') # Refer to local TVM build + import tvm from tvm import relax import tvm.testing @@ -50,10 +53,17 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]] gpu_out = vm["main"](gpu_data, *gpu_params) - pytorch_out = torch_module(torch_data).detach().numpy() - actual = gpu_out[0].numpy() - desired = pytorch_out - np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) + pytorch_out = torch_module(torch_data) + + if isinstance(pytorch_out, tuple): + for i in range(len(pytorch_out)): + actual = gpu_out[i].numpy() + desired = pytorch_out[i].detach().numpy() + np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) + else: + actual = gpu_out[0].numpy() + desired = pytorch_out.detach().numpy() + np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) @tvm.testing.parametrize_targets("cuda") @@ -202,11 +212,6 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) -@tvm.testing.parametrize_targets("cuda") -def test_copy_fails_(target, dev): - assert False, "test_copy_fails_ indeed fails" - - @tvm.testing.parametrize_targets("cuda") def test_upsample_with_size(target, dev): """ @@ -285,6 +290,28 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module2, target, dev) assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_split_size(target, dev): + + channels = 7 + split_size = 3 + dim = 0 # TODO try higher dims! + raw_data = np.random.rand(channels).astype("float32") + + class SplitModelSplitSize(nn.Module): + def __init__(self, split_size, dim): + super().__init__() + self.split_size = split_size + self.dim = dim + + def forward(self, x): + return torch.split(x, split_size_or_sections=self.split_size, dim=self.dim) + + torch_module = SplitModelSplitSize(split_size=split_size, dim=dim).eval() + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + if __name__ == "__main__": tvm.testing.main() From 75890ce59dfe85f99c6053734295614329335734 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 17:46:05 -0400 Subject: [PATCH 005/105] test sizes and lists --- .../relax/test_from_exported_to_cuda.py | 35 ++++++++++++++++--- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index be438cac0e05..482344f2dcb1 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -292,11 +292,14 @@ def forward(self, x): @tvm.testing.parametrize_targets("cuda") def test_split_size(target, dev): - + # Test split using the split_size argument such that it is not a divisor + # of the dimension to split (the last tensor will be smaller) + batch = 2 channels = 7 - split_size = 3 - dim = 0 # TODO try higher dims! 
- raw_data = np.random.rand(channels).astype("float32") + height, width = 2, 2 + split_size = 3 # last tensor will have just 1 element + dim = 1 # split across channels + raw_data = np.random.rand(batch, channels, height, width).astype("float32") class SplitModelSplitSize(nn.Module): def __init__(self, split_size, dim): @@ -311,6 +314,30 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_split_sections_list(target, dev): + # Test split using a list of section sizes + batch = 3 + channels = 2 + height = 10 + width = 5 + sections = [3, 2, 5] + dim = 2 # split across height + raw_data = np.random.rand(batch, channels, height, width).astype("float32") + + class SplitModelSectionsList(nn.Module): + def __init__(self, split_size, dim): + super().__init__() + self.split_size = split_size + self.dim = dim + + def forward(self, x): + return torch.split(x, split_size_or_sections=self.split_size, dim=self.dim) + + torch_module = SplitModelSectionsList(split_size=sections, dim=dim).eval() + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + if __name__ == "__main__": From 5a8eab116c27019d8796a814e1ffbd28a7d72fca Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 17:48:28 -0400 Subject: [PATCH 006/105] just one func --- include/tvm/topi/transform.h | 10 ---------- .../torch/base_fx_graph_translator.py | 20 ------------------- .../torch/exported_program_translator.py | 2 +- 3 files changed, 1 insertion(+), 31 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 62f69d9c825e..e459c0d43a81 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -578,16 +578,6 @@ inline Tensor stack(const Array& inputs, int axis = 0, std::string name inline Array split_indices_array(const Tensor& x, Array split_indices, int axis, std::string name = "T_split", std::string tag = kInjective) { - printf("we are in the transform'h's split\n"); - // int x11 = 10; - // int y11 = 5; - // while (y11) { - // x11 -= 2; - // y11--; - // } - // int z11 = 50 / x11; - // printf("z11 is %d\n", z11); - if (axis < 0) { axis += static_cast(x->shape.size()); } diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 40342a830674..6bbc9d5de618 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1021,7 +1021,6 @@ def _scatter(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.scatter_elements(x, index, src, axis=dim)) def _split(self, node: fx.Node) -> relax.Var: - """ torch.split with split_size passed as an argument""" x = self.env[node.args[0]] split_size = node.args[1] dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0) @@ -1032,25 +1031,6 @@ def _split(self, node: fx.Node) -> relax.Var: n_section.append(s + cum_sum) else: n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size - print("calling split with n_section", n_section) - print("calling split with dim", dim) - return self.block_builder.emit(relax.op.split(x, n_section, dim)) - - def _split_with_sizes(self, node: fx.Node) -> relax.Var: - """ torch.split with a list of section sizes passed as an argument""" - # TODO - x = self.env[node.args[0]] - split_size = node.args[1] - dim = node.args[2] if len(node.args) > 2 else 
node.kwargs.get("dim", 0) - if isinstance(split_size, (list, tuple)): - n_section = [] - for s in split_size[:-1]: - cum_sum = 0 if not n_section else n_section[-1] - n_section.append(s + cum_sum) - else: - n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size - print("calling split with n_section", n_section) - print("calling split with dim", dim) return self.block_builder.emit(relax.op.split(x, n_section, dim)) def _squeeze(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 834002a02170..2abc0b024871 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -329,7 +329,7 @@ def create_convert_map( "select.int": self._select, "slice.Tensor": self._slice, "split.Tensor": self._split, - "split_with_sizes.default": self._split_with_sizes, + "split_with_sizes.default": self._split, "squeeze.default": self._squeeze, "squeeze.dim": self._squeeze, "take.default": self._take, From c771389462f4620d4ed86194109975226bd5468f Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 17:52:22 -0400 Subject: [PATCH 007/105] cleanup --- include/tvm/topi/transform.h | 1 - python/tvm/relax/op/manipulate.py | 1 - python/tvm/relax/transform/legalize_ops/manipulate.py | 8 -------- src/contrib/msc/framework/tensorrt/transform_tensorrt.cc | 2 -- src/relax/op/tensor/manipulate.cc | 2 -- src/topi/transform.cc | 3 --- tests/python/relax/test_from_exported_to_cuda.py | 2 -- 7 files changed, 19 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index e459c0d43a81..762148dcfac3 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -972,7 +972,6 @@ inline Tensor strided_slice(const Tensor& x, const Array& begin, const inline Array split_n_sections(const Tensor& x, int num_sections, int axis, std::string name = "T_split_sections", std::string tag = kInjective) { - printf("We are in transform.h's splits_sections\n"); if (axis < 0) { axis += static_cast(x->shape.size()); } diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 85f08b7f98bd..0f6e537ab3d6 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -253,7 +253,6 @@ def split( """ if isinstance(indices_or_sections, int): indices_or_sections = IntImm("int64", indices_or_sections) - print("CALLING _ffi_api.split !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") return _ffi_api.split(x, indices_or_sections, axis) # type: ignore diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index cefbba55fce3..55e3aeb4423c 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -107,19 +107,11 @@ def _permute_dims(bb: BlockBuilder, call: Call) -> Expr: @register_legalize("relax.split") def _split(bb: BlockBuilder, call: Call) -> Expr: - print("is this the split op??????????????????????????????????/") if isinstance(call.attrs.indices_or_sections, tir.IntImm): indices_or_sections = call.attrs.indices_or_sections.value modulo = tvm.arith.Analyzer().simplify( call.args[0].struct_info.shape.values[call.attrs.axis] % indices_or_sections ) - # if isinstance(modulo, tir.IntImm): - # if modulo != 0: - # logging.info( - # "Split cannot be legalized by TOPI when the axis being 
split has " - # "length that not divisible by the input number of section." - # ) - # return call else: indices_or_sections = call.attrs.indices_or_sections return bb.call_te(topi.split, call.args[0], indices_or_sections, call.attrs.axis) diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc index 92d090946cdb..94dfec7ea621 100644 --- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc +++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc @@ -782,8 +782,6 @@ Expr RewriteShapeLike(BlockBuilder builder, const Var& var, const Call& src_call Expr RewriteSplit(BlockBuilder builder, const Var& var, const Call& src_call, const Map& new_calls, const String& config) { - printf("INSIDE REWRITESPLIT !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!11\n"); - std::cout << "Why is this not building?" << std::endl; const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; const auto& input_shape = ExprUtils::GetShape(call->args[0]); const auto* src_attrs = src_call->attrs.as(); diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index f40f0ac36e0a..6489ce36b47b 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -827,7 +827,6 @@ TVM_REGISTER_NODE_TYPE(SplitAttrs); #include Expr split(Expr x, Variant> indices_or_sections, int axis) { - printf("INSIDE SPLIT!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"); ObjectPtr attrs = make_object(); if (const auto* indices = indices_or_sections.as()) { for (int i = 0; i < static_cast(indices->size()); ++i) { @@ -851,7 +850,6 @@ Expr split(Expr x, Variant> indices_or_sections, int axis) attrs->axis = axis; static const Op& op = Op::Get("relax.split"); - printf("CALLING OP !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"); return Call(op, {std::move(x)}, Attrs(attrs), {}); } diff --git a/src/topi/transform.cc b/src/topi/transform.cc index a0df64faa291..7ef63a9b3f56 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -82,10 +82,7 @@ TVM_REGISTER_GLOBAL("topi.ndarray_size").set_body([](TVMArgs args, TVMRetValue* *rv = ndarray_size(args[0], args[1]); }); -#include TVM_REGISTER_GLOBAL("topi.split").set_body([](TVMArgs args, TVMRetValue* rv) { - printf("we are in transform.cc's topi.split, called by _split python func in manipulate.py\n"); - if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) { *rv = split_n_sections(args[0], args[1], args[2]); } else { diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 482344f2dcb1..d629d3fc87d6 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. 
-import sys -sys.path.append('/ssd1/htalendr/tvm/python') # Refer to local TVM build import tvm from tvm import relax From 2fbe4c1cc03e4e31369394741d2e149899f9c9d6 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 17:56:57 -0400 Subject: [PATCH 008/105] no assert --- src/relax/op/tensor/manipulate.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 6489ce36b47b..cb738db363ee 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -824,8 +824,6 @@ TVM_REGISTER_OP("relax.reshape") /* relax.split */ TVM_REGISTER_NODE_TYPE(SplitAttrs); -#include - Expr split(Expr x, Variant> indices_or_sections, int axis) { ObjectPtr attrs = make_object(); if (const auto* indices = indices_or_sections.as()) { From e5095b8dec610e78c4d7435b05b6d4c6727402ee Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 18:01:10 -0400 Subject: [PATCH 009/105] linting --- tests/python/relax/test_from_exported_to_cuda.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index d629d3fc87d6..c120eb89811c 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -288,15 +288,16 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module2, target, dev) assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev) + @tvm.testing.parametrize_targets("cuda") def test_split_size(target, dev): - # Test split using the split_size argument such that it is not a divisor + # Test split using the split_size argument such that it is not a divisor # of the dimension to split (the last tensor will be smaller) batch = 2 channels = 7 height, width = 2, 2 - split_size = 3 # last tensor will have just 1 element - dim = 1 # split across channels + split_size = 3 # last tensor will have just 1 element + dim = 1 # split across channels raw_data = np.random.rand(batch, channels, height, width).astype("float32") class SplitModelSplitSize(nn.Module): @@ -312,6 +313,7 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + @tvm.testing.parametrize_targets("cuda") def test_split_sections_list(target, dev): # Test split using a list of section sizes @@ -320,7 +322,7 @@ def test_split_sections_list(target, dev): height = 10 width = 5 sections = [3, 2, 5] - dim = 2 # split across height + dim = 2 # split across height raw_data = np.random.rand(batch, channels, height, width).astype("float32") class SplitModelSectionsList(nn.Module): @@ -337,6 +339,5 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - if __name__ == "__main__": tvm.testing.main() From 490a454c20f429fe54425182fde49bd5ccedf4a6 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 18:19:32 -0400 Subject: [PATCH 010/105] chunk --- .../torch/base_fx_graph_translator.py | 8 ++++++ .../torch/exported_program_translator.py | 1 + .../relax/test_from_exported_to_cuda.py | 25 +++++++++++++++++++ 3 files changed, 34 insertions(+) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 6bbc9d5de618..90598640aa21 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ 
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -935,6 +935,14 @@ def _cat(self, node: fx.Node) -> relax.Var: axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) return self.block_builder.emit(relax.op.concat(args[0], axis=axis)) + def _chunk(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + chunks = node.args[1] + dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0) + length_dim = int(self.shape_of(x)[dim]) + n_section = math.ceil(length_dim / chunks) + return self.block_builder.emit(relax.op.split(x, n_section, dim)) + def _cumsum(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 2abc0b024871..23d4063748f2 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -316,6 +316,7 @@ def create_convert_map( "argmin.default": self._argmax_argmin(relax.op.argmin), # tensor manipulation "cat.default": self._cat, + "chunk.default": self._chunk, "clamp.Tensor": self._clamp, "concat.default": self._cat, "copy_.default": self._copy_, diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index c120eb89811c..d834b54b51a1 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -15,6 +15,10 @@ # specific language governing permissions and limitations # under the License. +import sys + +sys.path.append("/ssd1/htalendr/tvm/python") # Refer to local TVM build + import tvm from tvm import relax @@ -336,6 +340,27 @@ def forward(self, x): torch_module = SplitModelSectionsList(split_size=sections, dim=dim).eval() + +@tvm.testing.parametrize_targets("cuda") +def test_chunk(target, dev): + batch = 3 + channels = 5 + height = 7 + width = 11 + chunks = 2 + dim = 1 + raw_data = np.random.rand(batch, channels, height, width).astype("float32") + + class ChunkModel(nn.Module): + def __init__(self, chunks, dim): + super().__init__() + self.chunks = chunks + self.dim = dim + + def forward(self, x): + return x.chunk(self.chunks, dim=self.dim) + + torch_module = ChunkModel(chunks=chunks, dim=dim).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) From ec7311cdb03f86686f19412a20acb8d4eed41fcc Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Thu, 20 Mar 2025 08:51:06 -0400 Subject: [PATCH 011/105] remove unsused modulo --- python/tvm/relax/transform/legalize_ops/manipulate.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 55e3aeb4423c..662d4e946b5f 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -16,7 +16,6 @@ # under the License. 
# pylint: disable=invalid-name """Default legalization function for manipulate operators.""" -import logging from typing import Optional import tvm @@ -109,9 +108,6 @@ def _permute_dims(bb: BlockBuilder, call: Call) -> Expr: def _split(bb: BlockBuilder, call: Call) -> Expr: if isinstance(call.attrs.indices_or_sections, tir.IntImm): indices_or_sections = call.attrs.indices_or_sections.value - modulo = tvm.arith.Analyzer().simplify( - call.args[0].struct_info.shape.values[call.attrs.axis] % indices_or_sections - ) else: indices_or_sections = call.attrs.indices_or_sections return bb.call_te(topi.split, call.args[0], indices_or_sections, call.attrs.axis) From 0744701366475bc876dc68aff51d1a612095b48b Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Thu, 20 Mar 2025 11:13:34 -0400 Subject: [PATCH 012/105] fixed first test --- .../test_transform_legalize_ops_manipulate.py | 139 +++++++++++------- 1 file changed, 86 insertions(+), 53 deletions(-) diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index 0565b7a5790a..db1b5c08ceca 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. +import sys +sys.path.append('/ssd1/htalendr/tvm/python') + import tvm from tvm import relax from tvm.relax.transform import LegalizeOps @@ -788,102 +791,132 @@ def test_split_by_indices_n_section_indivisible(): class Split: @R.function def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]): - gv: R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]) = R.split(x, 3, axis=1) - return gv - # fmt: on - - mod = LegalizeOps()(Split) - tvm.ir.assert_structural_equal(mod, Split) - - -def test_split_by_indices_n_section_divisible(): - # fmt: off - @tvm.script.ir_module - class Split: - @R.function - def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]): - gv: R.Tuple([R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]) = R.split(x, 2, axis=1) + gv: R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]) = R.split(x, indices_or_sections=3, axis=1) return gv @tvm.script.ir_module class Expected: @R.function - def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]): - gv = R.call_tir(Expected.split, (x,), [R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]) + def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]): + gv = R.call_tir(Expected.split, (x,), [R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]) return gv @T.prim_func(private=True) - def split(rxplaceholder: T.Buffer((T.int64(2), T.int64(10), T.int64(4)), "float32"), T_split_sections: T.Buffer((T.int64(2), T.int64(5), T.int64(4)), "float32"), T_split_sections_1: T.Buffer((T.int64(2), T.int64(5), T.int64(4)), "float32")): + def split(rxplaceholder: T.Buffer((T.int64(2), T.int64(10), T.int64(4)), "float32"), T_split_sections: T.Buffer((T.int64(2), T.int64(4), T.int64(4)), "float32"), T_split_sections_1: 
T.Buffer((T.int64(2), T.int64(4), T.int64(4)), "float32"), T_split_sections_2: T.Buffer((T.int64(2), T.int64(2), T.int64(4)), "float32")): T.func_attr({"tir.noalias": True}) - for i0, i1, i2 in T.grid(T.int64(2), T.int64(5), T.int64(4)): + for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(4)): with T.block("T_split_sections"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[ax0, ax1, ax2]) T.writes(T_split_sections[ax0, ax1, ax2]) T_split_sections[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] - for i0, i1, i2 in T.grid(T.int64(2), T.int64(5), T.int64(4)): + for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(4)): with T.block("T_split_sections_1"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(rxplaceholder[ax0, ax1 + T.int64(5), ax2]) + T.reads(rxplaceholder[ax0, ax1 + T.int64(4), ax2]) T.writes(T_split_sections_1[ax0, ax1, ax2]) - T_split_sections_1[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 + T.int64(5), ax2] + T_split_sections_1[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 + T.int64(4), ax2] + for i0, i1, i2 in T.grid(T.int64(2), T.int64(2), T.int64(4)): + with T.block("T_split_sections_2"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1 + T.int64(8), ax2]) + T.writes(T_split_sections_2[ax0, ax1, ax2]) + T_split_sections_2[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 + T.int64(8), ax2] + # fmt: on mod = LegalizeOps()(Split) tvm.ir.assert_structural_equal(mod, Expected) -def test_split_by_indices_n_section_divisible_symbolic(): +def test_split_by_indices_n_section_divisible(): # fmt: off @tvm.script.ir_module class Split: @R.function - def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "n * 3"), "float32")) -> R.Tuple([R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32")]): - m = T.int64() - n = T.int64() - gv: R.Tuple([R.Tensor((m, n), "float32"), R.Tensor((m, n), "float32"), R.Tensor((m, n), "float32")]) = R.split(x, 3, axis=1) + def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]): + gv: R.Tuple([R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]) = R.split(x, 2, axis=1) return gv @tvm.script.ir_module class Expected: @R.function - def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "(n * 3)"), "float32")) -> R.Tuple(R.Tensor(("m", "((n * 3) // 3)"), "float32"), R.Tensor(("m", "((((n * 3) // 3) * 2) - ((n * 3) // 3))"), "float32"), R.Tensor(("m", "((n * 3) - (((n * 3) // 3) * 2))"), "float32")): - m = T.int64() - n = T.int64() - gv = R.call_tir(Expected.split, (x,), [R.Tensor((m, ((n * 3) // 3)), "float32"), R.Tensor((m, ((((n * 3) // 3) * 2) - ((n * 3) // 3))), "float32"), R.Tensor((m, ((n * 3) - (((n * 3) // 3) * 2))), "float32")], tir_vars=(n,)) + def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]): + gv = R.call_tir(Expected.split, (x,), [R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]) return gv @T.prim_func(private=True) - def split(var_rxplaceholder: T.handle, var_T_split_sections: T.handle, var_T_split_sections_1: T.handle, var_T_split_sections_2: T.handle, n: T.int64): + def split(rxplaceholder: T.Buffer((T.int64(2), T.int64(10), T.int64(4)), "float32"), T_split_sections: T.Buffer((T.int64(2), T.int64(5), T.int64(4)), "float32"), T_split_sections_1: T.Buffer((T.int64(2), T.int64(5), T.int64(4)), "float32")): T.func_attr({"tir.noalias": True}) - m = T.int64() - rxplaceholder = 
T.match_buffer(var_rxplaceholder, [m, n * T.int64(3)], dtype="float32") - T_split_sections = T.match_buffer(var_T_split_sections, [m, n * T.int64(3) // T.int64(3)], dtype="float32") - T_split_sections_1 = T.match_buffer(var_T_split_sections_1, [m, n * T.int64(3) // T.int64(3) * T.int64(2) - n * T.int64(3) // T.int64(3)], dtype="float32") - T_split_sections_2 = T.match_buffer(var_T_split_sections_2, [m, n * T.int64(3) - n * T.int64(3) // T.int64(3) * T.int64(2)], dtype="float32") - for i0, i1 in T.grid(m, n): + for i0, i1, i2 in T.grid(T.int64(2), T.int64(5), T.int64(4)): with T.block("T_split_sections"): - ax0, ax1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[ax0, ax1]) - T.writes(T_split_sections[ax0, ax1]) - T_split_sections[ax0, ax1] = rxplaceholder[ax0, ax1] - for i0, i1 in T.grid(m, n): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1, ax2]) + T.writes(T_split_sections[ax0, ax1, ax2]) + T_split_sections[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] + for i0, i1, i2 in T.grid(T.int64(2), T.int64(5), T.int64(4)): with T.block("T_split_sections_1"): - ax0, ax1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[ax0, n + ax1]) - T.writes(T_split_sections_1[ax0, ax1]) - T_split_sections_1[ax0, ax1] = rxplaceholder[ax0, n + ax1] - for i0, i1 in T.grid(m, n): - with T.block("T_split_sections_2"): - ax0, ax1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[ax0, n * T.int64(2) + ax1]) - T.writes(T_split_sections_2[ax0, ax1]) - T_split_sections_2[ax0, ax1] = rxplaceholder[ax0, n * T.int64(2) + ax1] + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1 + T.int64(5), ax2]) + T.writes(T_split_sections_1[ax0, ax1, ax2]) + T_split_sections_1[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 + T.int64(5), ax2] # fmt: on mod = LegalizeOps()(Split) tvm.ir.assert_structural_equal(mod, Expected) +# TODO uncomment +# def test_split_by_indices_n_section_divisible_symbolic(): +# # fmt: off +# @tvm.script.ir_module +# class Split: +# @R.function +# def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "n * 3"), "float32")) -> R.Tuple([R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32")]): +# m = T.int64() +# n = T.int64() +# gv: R.Tuple([R.Tensor((m, n), "float32"), R.Tensor((m, n), "float32"), R.Tensor((m, n), "float32")]) = R.split(x, 3, axis=1) +# return gv + +# @tvm.script.ir_module +# class Expected: +# @R.function +# def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "(n * 3)"), "float32")) -> R.Tuple(R.Tensor(("m", "((n * 3) // 3)"), "float32"), R.Tensor(("m", "((((n * 3) // 3) * 2) - ((n * 3) // 3))"), "float32"), R.Tensor(("m", "((n * 3) - (((n * 3) // 3) * 2))"), "float32")): +# m = T.int64() +# n = T.int64() +# gv = R.call_tir(Expected.split, (x,), [R.Tensor((m, ((n * 3) // 3)), "float32"), R.Tensor((m, ((((n * 3) // 3) * 2) - ((n * 3) // 3))), "float32"), R.Tensor((m, ((n * 3) - (((n * 3) // 3) * 2))), "float32")], tir_vars=(n,)) +# return gv + +# @T.prim_func(private=True) +# def split(var_rxplaceholder: T.handle, var_T_split_sections: T.handle, var_T_split_sections_1: T.handle, var_T_split_sections_2: T.handle, n: T.int64): +# T.func_attr({"tir.noalias": True}) +# m = T.int64() +# rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n * T.int64(3)], dtype="float32") +# T_split_sections = T.match_buffer(var_T_split_sections, [m, n * T.int64(3) // T.int64(3)], dtype="float32") +# T_split_sections_1 = T.match_buffer(var_T_split_sections_1, [m, n * T.int64(3) // 
T.int64(3) * T.int64(2) - n * T.int64(3) // T.int64(3)], dtype="float32") +# T_split_sections_2 = T.match_buffer(var_T_split_sections_2, [m, n * T.int64(3) - n * T.int64(3) // T.int64(3) * T.int64(2)], dtype="float32") +# for i0, i1 in T.grid(m, n): +# with T.block("T_split_sections"): +# ax0, ax1 = T.axis.remap("SS", [i0, i1]) +# T.reads(rxplaceholder[ax0, ax1]) +# T.writes(T_split_sections[ax0, ax1]) +# T_split_sections[ax0, ax1] = rxplaceholder[ax0, ax1] +# for i0, i1 in T.grid(m, n): +# with T.block("T_split_sections_1"): +# ax0, ax1 = T.axis.remap("SS", [i0, i1]) +# T.reads(rxplaceholder[ax0, n + ax1]) +# T.writes(T_split_sections_1[ax0, ax1]) +# T_split_sections_1[ax0, ax1] = rxplaceholder[ax0, n + ax1] +# for i0, i1 in T.grid(m, n): +# with T.block("T_split_sections_2"): +# ax0, ax1 = T.axis.remap("SS", [i0, i1]) +# T.reads(rxplaceholder[ax0, n * T.int64(2) + ax1]) +# T.writes(T_split_sections_2[ax0, ax1]) +# T_split_sections_2[ax0, ax1] = rxplaceholder[ax0, n * T.int64(2) + ax1] +# # fmt: on + +# mod = LegalizeOps()(Split) +# tvm.ir.assert_structural_equal(mod, Expected) + def test_squeeze(): # fmt: off From 40f171183e9a8194781ebe2d7592ea377756b0bd Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Thu, 20 Mar 2025 11:34:23 -0400 Subject: [PATCH 013/105] fixed second test and lint --- .../test_transform_legalize_ops_manipulate.py | 105 +++++++++--------- 1 file changed, 53 insertions(+), 52 deletions(-) diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index db1b5c08ceca..4836ffd01041 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -16,7 +16,8 @@ # under the License. 
import sys -sys.path.append('/ssd1/htalendr/tvm/python') + +sys.path.append("/ssd1/htalendr/tvm/python") import tvm from tvm import relax @@ -865,57 +866,57 @@ def split(rxplaceholder: T.Buffer((T.int64(2), T.int64(10), T.int64(4)), "float3 mod = LegalizeOps()(Split) tvm.ir.assert_structural_equal(mod, Expected) -# TODO uncomment -# def test_split_by_indices_n_section_divisible_symbolic(): -# # fmt: off -# @tvm.script.ir_module -# class Split: -# @R.function -# def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "n * 3"), "float32")) -> R.Tuple([R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32")]): -# m = T.int64() -# n = T.int64() -# gv: R.Tuple([R.Tensor((m, n), "float32"), R.Tensor((m, n), "float32"), R.Tensor((m, n), "float32")]) = R.split(x, 3, axis=1) -# return gv - -# @tvm.script.ir_module -# class Expected: -# @R.function -# def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "(n * 3)"), "float32")) -> R.Tuple(R.Tensor(("m", "((n * 3) // 3)"), "float32"), R.Tensor(("m", "((((n * 3) // 3) * 2) - ((n * 3) // 3))"), "float32"), R.Tensor(("m", "((n * 3) - (((n * 3) // 3) * 2))"), "float32")): -# m = T.int64() -# n = T.int64() -# gv = R.call_tir(Expected.split, (x,), [R.Tensor((m, ((n * 3) // 3)), "float32"), R.Tensor((m, ((((n * 3) // 3) * 2) - ((n * 3) // 3))), "float32"), R.Tensor((m, ((n * 3) - (((n * 3) // 3) * 2))), "float32")], tir_vars=(n,)) -# return gv - -# @T.prim_func(private=True) -# def split(var_rxplaceholder: T.handle, var_T_split_sections: T.handle, var_T_split_sections_1: T.handle, var_T_split_sections_2: T.handle, n: T.int64): -# T.func_attr({"tir.noalias": True}) -# m = T.int64() -# rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n * T.int64(3)], dtype="float32") -# T_split_sections = T.match_buffer(var_T_split_sections, [m, n * T.int64(3) // T.int64(3)], dtype="float32") -# T_split_sections_1 = T.match_buffer(var_T_split_sections_1, [m, n * T.int64(3) // T.int64(3) * T.int64(2) - n * T.int64(3) // T.int64(3)], dtype="float32") -# T_split_sections_2 = T.match_buffer(var_T_split_sections_2, [m, n * T.int64(3) - n * T.int64(3) // T.int64(3) * T.int64(2)], dtype="float32") -# for i0, i1 in T.grid(m, n): -# with T.block("T_split_sections"): -# ax0, ax1 = T.axis.remap("SS", [i0, i1]) -# T.reads(rxplaceholder[ax0, ax1]) -# T.writes(T_split_sections[ax0, ax1]) -# T_split_sections[ax0, ax1] = rxplaceholder[ax0, ax1] -# for i0, i1 in T.grid(m, n): -# with T.block("T_split_sections_1"): -# ax0, ax1 = T.axis.remap("SS", [i0, i1]) -# T.reads(rxplaceholder[ax0, n + ax1]) -# T.writes(T_split_sections_1[ax0, ax1]) -# T_split_sections_1[ax0, ax1] = rxplaceholder[ax0, n + ax1] -# for i0, i1 in T.grid(m, n): -# with T.block("T_split_sections_2"): -# ax0, ax1 = T.axis.remap("SS", [i0, i1]) -# T.reads(rxplaceholder[ax0, n * T.int64(2) + ax1]) -# T.writes(T_split_sections_2[ax0, ax1]) -# T_split_sections_2[ax0, ax1] = rxplaceholder[ax0, n * T.int64(2) + ax1] -# # fmt: on - -# mod = LegalizeOps()(Split) -# tvm.ir.assert_structural_equal(mod, Expected) + +def test_split_by_indices_n_section_divisible_symbolic(): + # fmt: off + @tvm.script.ir_module + class Split: + @R.function + def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "n * 3"), "float32")) -> R.Tuple([R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32")]): + m = T.int64() + n = T.int64() + gv: R.Tuple([R.Tensor((m, n), "float32"), R.Tensor((m, n), "float32"), R.Tensor((m, n), "float32")]) = R.split(x, 3, axis=1) + 
return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "(n * 3)"), "float32")) -> R.Tuple(R.Tensor(("m", "((n * 3) // 3)"), "float32"), R.Tensor(("m", "((((n * 3) // 3) * 2) - ((n * 3) // 3))"), "float32"), R.Tensor(("m", "((n * 3) - (((n * 3) // 3) * 2))"), "float32")): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.split, (x,), [R.Tensor((m, ((n * 3 + 3 - 1) // 3)), "float32"), R.Tensor((m, ((((n * 3 + 3 - 1) // 3) * 2) - ((n * 3 + 3 - 1) // 3))), "float32"), R.Tensor((m, ((n * 3) - (((n * 3 + 3 - 1) // 3) * 2))), "float32")], tir_vars=R.shape([n])) + return gv + + @T.prim_func(private=True) + def split(var_rxplaceholder: T.handle, var_T_split_sections: T.handle, var_T_split_sections_1: T.handle, var_T_split_sections_2: T.handle, n: T.int64): + T.func_attr({"tir.noalias": True}) + m = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n * T.int64(3)], dtype="float32") + T_split_sections = T.match_buffer(var_T_split_sections, [m, (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3)], dtype="float32") + T_split_sections_1 = T.match_buffer(var_T_split_sections_1, [m, (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3) * T.int64(2) - (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3)], dtype="float32") + T_split_sections_2 = T.match_buffer(var_T_split_sections_2, [m, n * T.int64(3) - (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3) * T.int64(2)], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("T_split_sections"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_split_sections[ax0, ax1]) + T_split_sections[ax0, ax1] = rxplaceholder[ax0, ax1] + for i0, i1 in T.grid(m, n): + with T.block("T_split_sections_1"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1 + n]) + T.writes(T_split_sections_1[ax0, ax1]) + T_split_sections_1[ax0, ax1] = rxplaceholder[ax0, ax1 + n] + for i0, i1 in T.grid(m, n): + with T.block("T_split_sections_2"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, n * T.int64(2) + ax1]) + T.writes(T_split_sections_2[ax0, ax1]) + T_split_sections_2[ax0, ax1] = rxplaceholder[ax0, n * T.int64(2) + ax1] + # fmt: on + + mod = LegalizeOps()(Split) + tvm.ir.assert_structural_equal(mod, Expected) def test_squeeze(): From 4582c3a7e94e6dbe5c8a86e351a3e5e8ec8f505c Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 24 Mar 2025 12:25:20 -0400 Subject: [PATCH 014/105] linting --- tests/python/relax/test_from_exported_to_cuda.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index c28a2ff63f1f..12b50b3886f3 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -335,6 +335,7 @@ def forward(self, x): torch_module = SplitModelSectionsList(split_size=sections, dim=dim).eval() + @tvm.testing.parametrize_targets("cuda") def test_chunk(target, dev): batch = 3 From afa793a9fe791ae729d71ac895443820e588fa5e Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 24 Mar 2025 12:25:59 -0400 Subject: [PATCH 015/105] fix one test --- tests/python/relax/test_from_exported_to_cuda.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 12b50b3886f3..5c7f30d14819 100644 --- 
a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -334,6 +334,7 @@ def forward(self, x): return torch.split(x, split_size_or_sections=self.split_size, dim=self.dim) torch_module = SplitModelSectionsList(split_size=sections, dim=dim).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) @tvm.testing.parametrize_targets("cuda") From d71b51855b9053df98bc0f99b09af85d78b8085d Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 24 Mar 2025 14:20:06 -0400 Subject: [PATCH 016/105] chunk not passing anymore --- tests/python/relax/test_from_exported_to_cuda.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 5c7f30d14819..8b788bf1628e 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -341,9 +341,9 @@ def forward(self, x): def test_chunk(target, dev): batch = 3 channels = 5 - height = 7 - width = 11 - chunks = 2 + height = 2 + width = 7 + chunks = 11 dim = 1 raw_data = np.random.rand(batch, channels, height, width).astype("float32") From e95aef67053c1093f2267c064beae4e1ca2beff9 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 24 Mar 2025 14:23:34 -0400 Subject: [PATCH 017/105] get_item error --- tests/python/relax/test_from_exported_to_cuda.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 8b788bf1628e..11e55919e139 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+import sys +sys.path.append('/ssd1/htalendr/tvm/python') import tvm from tvm import relax From bc504461c8c072b0455d2871b1b73913bfe9d2ed Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 24 Mar 2025 14:55:46 -0400 Subject: [PATCH 018/105] chunk unit tests --- .../torch/base_fx_graph_translator.py | 9 ++- .../relax/test_from_exported_to_cuda.py | 59 +++++++++++++++++-- 2 files changed, 59 insertions(+), 9 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 176b42ce997b..2d73c8dced5e 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -960,9 +960,12 @@ def _chunk(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] chunks = node.args[1] dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0) - length_dim = int(self.shape_of(x)[dim]) - n_section = math.ceil(length_dim / chunks) - return self.block_builder.emit(relax.op.split(x, n_section, dim)) + x_shape = self.shape_of(x) + max_chunks = x_shape[dim].value + n_sections = min(chunks, max_chunks) + return self.block_builder.emit( + relax.op.split(x=x, indices_or_sections=n_sections, axis=dim) + ) def _cumsum(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 11e55919e139..77d0db0b4a99 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import sys -sys.path.append('/ssd1/htalendr/tvm/python') import tvm from tvm import relax @@ -340,12 +338,61 @@ def forward(self, x): @tvm.testing.parametrize_targets("cuda") -def test_chunk(target, dev): - batch = 3 +def test_chunk_even(target, dev): + # Chunks is a divisor of the dimension size + batch = 6 + channels = 2 + height = 3 + width = 4 + chunks = 3 + dim = 0 + raw_data = np.random.rand(batch, channels, height, width).astype("float32") + + class ChunkModel(nn.Module): + def __init__(self, chunks, dim): + super().__init__() + self.chunks = chunks + self.dim = dim + + def forward(self, x): + return x.chunk(self.chunks, dim=self.dim) + + torch_module = ChunkModel(chunks=chunks, dim=dim).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_chunk_uneven(target, dev): + # Chunks is not a divisor of the dimension size + batch = 2 channels = 5 + height = 4 + width = 5 + chunks = 2 + dim = 1 + raw_data = np.random.rand(batch, channels, height, width).astype("float32") + + class ChunkModel(nn.Module): + def __init__(self, chunks, dim): + super().__init__() + self.chunks = chunks + self.dim = dim + + def forward(self, x): + return x.chunk(self.chunks, dim=self.dim) + + torch_module = ChunkModel(chunks=chunks, dim=dim).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_chunk_too_many(target, dev): + # If user asks for more chunks than the size of the dim, pytorch simply splits in sections of size 1 + batch = 1 + channels = 3 height = 2 - width = 7 - chunks = 11 + width = 2 + chunks = 99 dim = 1 raw_data = np.random.rand(batch, channels, height, 
width).astype("float32") From db5ec018e3a3869496029dd84237934453186995 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Fri, 28 Mar 2025 14:48:48 -0400 Subject: [PATCH 019/105] index select test passes --- .../frontend/torch/base_fx_graph_translator.py | 6 ++++++ .../frontend/torch/exported_program_translator.py | 4 +++- python/tvm/relax/frontend/torch/fx_translator.py | 6 ------ tests/python/relax/test_from_exported_to_cuda.py | 13 +++++++++++++ 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 71554a8a5bab..aa8150fac71c 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1193,6 +1193,12 @@ def _fill(self, node: fx.Node) -> relax.Var: value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) + def _index_select(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] + index = self.env[node.args[2]] + return self.block_builder.emit(relax.op.take(x, index, dim)) + def _new_ones(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) self_var = args[0] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 4319fbebe74a..a7b3096b4a9e 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -386,7 +386,6 @@ def create_convert_map( "reshape.default": self._reshape, # tensor creation "_to_copy.default": self._to_copy, - "lift_fresh_copy.default": self._to_copy, "detach.default": self._detach, "detach_.default": self._detach, "arange.start": self._arange, @@ -395,6 +394,8 @@ def create_convert_map( "empty.memory_format": self._empty, "empty_like.default": self._empty_like, "fill.Scalar": self._fill, + "index_select.default": self._index_select, + "lift_fresh_copy.default": self._to_copy, "new_ones.default": self._new_ones, "one_hot.default": self._one_hot, # other @@ -484,6 +485,7 @@ def from_exported_program( self.env[node] = getattr(exported_program.graph_module, node.target) elif node.op == "call_function": func_name = node.target.__name__ + print("unsing function", func_name) assert ( func_name in self.convert_map ), f"Unsupported function type {func_name}" diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 022a7bffea80..99cde790d63e 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -485,12 +485,6 @@ def _full(self, node: fx.Node) -> relax.Var: ) ) - def _index_select(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - dim = node.args[1] - index = self.env[node.args[2]] - return self.block_builder.emit(relax.op.take(x, index, dim)) - def _inplace_masked_fill(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] mask = self.env[node.args[1]] diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 64babdc43a5c..4db810e5ee0f 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -467,5 +467,18 @@ def forward(self, x): 
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_index_select(target, dev): + class IndexSelectModel(nn.Module): + def forward(self, x): + indices = torch.tensor([0, 2]) + return torch.index_select(x, 0, indices) + + raw_data = np.random.rand(3, 4).astype("float32") + torch_module = IndexSelectModel().eval() + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + if __name__ == "__main__": tvm.testing.main() From c39e6e13a466a5c5a9e1174eb9268903bba93779 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Fri, 28 Mar 2025 14:52:49 -0400 Subject: [PATCH 020/105] fix test --- .../relax/test_from_exported_to_cuda.py | 51 ------------------- 1 file changed, 51 deletions(-) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 301c1efff5e4..19b8f80a2390 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -476,57 +476,6 @@ def forward(self, x): raw_data = np.random.rand(3, 4).astype("float32") torch_module = IndexSelectModel().eval() - - - torch_module = ChunkModel(chunks=chunks, dim=dim).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - - -@tvm.testing.parametrize_targets("cuda") -def test_chunk_uneven(target, dev): - # Chunks is not a divisor of the dimension size - batch = 2 - channels = 5 - height = 4 - width = 5 - chunks = 2 - dim = 1 - raw_data = np.random.rand(batch, channels, height, width).astype("float32") - - class ChunkModel(nn.Module): - def __init__(self, chunks, dim): - super().__init__() - self.chunks = chunks - self.dim = dim - - def forward(self, x): - return x.chunk(self.chunks, dim=self.dim) - - torch_module = ChunkModel(chunks=chunks, dim=dim).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - - -@tvm.testing.parametrize_targets("cuda") -def test_chunk_too_many(target, dev): - # If user asks for more chunks than the size of the dim, pytorch simply splits in sections of size 1 - batch = 1 - channels = 3 - height = 2 - width = 2 - chunks = 99 - dim = 1 - raw_data = np.random.rand(batch, channels, height, width).astype("float32") - - class ChunkModel(nn.Module): - def __init__(self, chunks, dim): - super().__init__() - self.chunks = chunks - self.dim = dim - - def forward(self, x): - return x.chunk(self.chunks, dim=self.dim) - - torch_module = ChunkModel(chunks=chunks, dim=dim).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) From f8d50f29f52ab32cd200f4a01b38de2f91baf815 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Fri, 28 Mar 2025 14:59:11 -0400 Subject: [PATCH 021/105] cleanup --- python/tvm/relax/frontend/torch/exported_program_translator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index cb60c178e448..5f6c4c902401 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -488,7 +488,6 @@ def from_exported_program( self.env[node] = getattr(exported_program.graph_module, node.target) elif node.op == "call_function": func_name = node.target.__name__ - print("unsing function", func_name) assert ( func_name in self.convert_map ), f"Unsupported function type 
{func_name}" From 086410ca3d257f241714e4c1013b9014f915c26f Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sat, 29 Mar 2025 16:59:20 -0400 Subject: [PATCH 022/105] index_select --- .../relax/frontend/torch/base_fx_graph_translator.py | 6 ++++++ .../frontend/torch/exported_program_translator.py | 3 ++- python/tvm/relax/frontend/torch/fx_translator.py | 6 ------ tests/python/relax/test_from_exported_to_cuda.py | 12 ++++++++++++ 4 files changed, 20 insertions(+), 7 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 71554a8a5bab..aa8150fac71c 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1193,6 +1193,12 @@ def _fill(self, node: fx.Node) -> relax.Var: value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) + def _index_select(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] + index = self.env[node.args[2]] + return self.block_builder.emit(relax.op.take(x, index, dim)) + def _new_ones(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) self_var = args[0] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 0f1dc11787da..5f6c4c902401 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -389,7 +389,6 @@ def create_convert_map( "reshape.default": self._reshape, # tensor creation "_to_copy.default": self._to_copy, - "lift_fresh_copy.default": self._to_copy, "detach.default": self._detach, "detach_.default": self._detach, "arange.start": self._arange, @@ -398,6 +397,8 @@ def create_convert_map( "empty.memory_format": self._empty, "empty_like.default": self._empty_like, "fill.Scalar": self._fill, + "index_select.default": self._index_select, + "lift_fresh_copy.default": self._to_copy, "new_ones.default": self._new_ones, "one_hot.default": self._one_hot, # other diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 022a7bffea80..99cde790d63e 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -485,12 +485,6 @@ def _full(self, node: fx.Node) -> relax.Var: ) ) - def _index_select(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - dim = node.args[1] - index = self.env[node.args[2]] - return self.block_builder.emit(relax.op.take(x, index, dim)) - def _inplace_masked_fill(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] mask = self.env[node.args[1]] diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 64babdc43a5c..19b8f80a2390 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -467,5 +467,17 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_index_select(target, dev): + class IndexSelectModel(nn.Module): + def forward(self, x): + indices = torch.tensor([0, 2]) + return torch.index_select(x, 0, indices) + + raw_data = np.random.rand(3, 4).astype("float32") + 
torch_module = IndexSelectModel().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + if __name__ == "__main__": tvm.testing.main() From 97ada56abedc9dfd8262594681c7d68c68824550 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 30 Mar 2025 11:58:50 -0400 Subject: [PATCH 023/105] arange.default ok --- .../torch/exported_program_translator.py | 6 +++-- .../relax/test_from_exported_to_cuda.py | 26 +++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 0f1dc11787da..684cf344cdd4 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -389,15 +389,16 @@ def create_convert_map( "reshape.default": self._reshape, # tensor creation "_to_copy.default": self._to_copy, - "lift_fresh_copy.default": self._to_copy, + "arange.default": self._arange, + "arange.start": self._arange, "detach.default": self._detach, "detach_.default": self._detach, - "arange.start": self._arange, "contiguous.default": lambda node: self.env[node.args[0]], # no-op "clone.default": lambda node: self.env[node.args[0]], "empty.memory_format": self._empty, "empty_like.default": self._empty_like, "fill.Scalar": self._fill, + "lift_fresh_copy.default": self._to_copy, "new_ones.default": self._new_ones, "one_hot.default": self._one_hot, # other @@ -490,6 +491,7 @@ def from_exported_program( assert ( func_name in self.convert_map ), f"Unsupported function type {func_name}" + print('found function!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!', func_name) self.env[node] = self.convert_map[func_name](node) else: raise ValueError(f"Unsupported op {node.op}") diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 64babdc43a5c..e77ed464050a 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -466,6 +466,32 @@ def forward(self, x): torch_module = ChunkModel(chunks=chunks, dim=dim).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_arange_default(target, dev): + raw_data = np.random.rand(5).astype("int64") + + class ArangeModel(nn.Module): + def forward(self, x): + return x + torch.arange(5) + + torch_module = ArangeModel().eval() + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +# TODO +# @tvm.testing.parametrize_targets("cuda") +# def test_arange_start_step(target, dev): +# raw_data = np.random.rand(3).astype("int64") + +# class ArangeModel(nn.Module): +# def forward(self, x): +# return x + torch.arange(1, 2.5, 0.5) + +# torch_module = ArangeModel().eval() + +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + if __name__ == "__main__": tvm.testing.main() From 3431e183fee7f1a3b1f686529edda0e24b1cfc2c Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 30 Mar 2025 12:03:29 -0400 Subject: [PATCH 024/105] all arange tests pass --- .../torch/exported_program_translator.py | 3 +- .../relax/test_from_exported_to_cuda.py | 31 +++++++++++++------ 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py 
b/python/tvm/relax/frontend/torch/exported_program_translator.py index 684cf344cdd4..7cc071d01e33 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -390,7 +390,8 @@ def create_convert_map( # tensor creation "_to_copy.default": self._to_copy, "arange.default": self._arange, - "arange.start": self._arange, + "arange.start": self._arange, # TODO test + "arange.start_step": self._arange, "detach.default": self._detach, "detach_.default": self._detach, "contiguous.default": lambda node: self.env[node.args[0]], # no-op diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index e77ed464050a..9a0cd2c67cda 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -468,7 +468,7 @@ def forward(self, x): @tvm.testing.parametrize_targets("cuda") def test_arange_default(target, dev): - raw_data = np.random.rand(5).astype("int64") + raw_data = np.array([0,0,0,0,0]) class ArangeModel(nn.Module): def forward(self, x): @@ -478,19 +478,30 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_arange_start(target, dev): + raw_data = np.array([0,0,0]) + + class ArangeModel(nn.Module): + def forward(self, x): + return x + torch.arange(1, 4) + + torch_module = ArangeModel().eval() + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) -# TODO -# @tvm.testing.parametrize_targets("cuda") -# def test_arange_start_step(target, dev): -# raw_data = np.random.rand(3).astype("int64") -# class ArangeModel(nn.Module): -# def forward(self, x): -# return x + torch.arange(1, 2.5, 0.5) +@tvm.testing.parametrize_targets("cuda") +def test_arange_start_step(target, dev): + raw_data = np.array([0.0,0.0,0.0], dtype=np.float32) -# torch_module = ArangeModel().eval() + class ArangeModel(nn.Module): + def forward(self, x): + return x + torch.arange(1, 2.5, 0.5, dtype=torch.float32) -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + torch_module = ArangeModel().eval() + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) if __name__ == "__main__": From 5c0aa4f824bb63dfe5c104525b9a6ebcdaf5180d Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 30 Mar 2025 12:06:25 -0400 Subject: [PATCH 025/105] arange test complete --- .../torch/exported_program_translator.py | 3 +- .../relax/test_from_exported_to_cuda.py | 32 ++++++++----------- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 7cc071d01e33..c0052280433b 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -390,7 +390,7 @@ def create_convert_map( # tensor creation "_to_copy.default": self._to_copy, "arange.default": self._arange, - "arange.start": self._arange, # TODO test + "arange.start": self._arange, "arange.start_step": self._arange, "detach.default": self._detach, "detach_.default": self._detach, @@ -492,7 +492,6 @@ def from_exported_program( assert ( func_name in self.convert_map ), f"Unsupported function type {func_name}" - print('found 
function!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!', func_name) self.env[node] = self.convert_map[func_name](node) else: raise ValueError(f"Unsupported op {node.op}") diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 9a0cd2c67cda..5a49c6f5f434 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -466,41 +466,37 @@ def forward(self, x): torch_module = ChunkModel(chunks=chunks, dim=dim).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + @tvm.testing.parametrize_targets("cuda") -def test_arange_default(target, dev): - raw_data = np.array([0,0,0,0,0]) +def test_arange(target, dev): + # arange.default + raw_data = np.array([0, 0, 0, 0, 0]) - class ArangeModel(nn.Module): + class ArangeDefaultModel(nn.Module): def forward(self, x): return x + torch.arange(5) - torch_module = ArangeModel().eval() - + torch_module = ArangeDefaultModel().eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) -@tvm.testing.parametrize_targets("cuda") -def test_arange_start(target, dev): - raw_data = np.array([0,0,0]) + # arange.start + raw_data = np.array([0, 0, 0]) - class ArangeModel(nn.Module): + class ArangeStartModel(nn.Module): def forward(self, x): return x + torch.arange(1, 4) - torch_module = ArangeModel().eval() - + torch_module = ArangeStartModel().eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + # arange.start_step + raw_data = np.array([0.0, 0.0, 0.0], dtype=np.float32) -@tvm.testing.parametrize_targets("cuda") -def test_arange_start_step(target, dev): - raw_data = np.array([0.0,0.0,0.0], dtype=np.float32) - - class ArangeModel(nn.Module): + class ArangeStartStopModel(nn.Module): def forward(self, x): return x + torch.arange(1, 2.5, 0.5, dtype=torch.float32) - torch_module = ArangeModel().eval() - + torch_module = ArangeStartStopModel().eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) From 353c399b2885b63555d45c9d10fc6fc5b074522a Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 30 Mar 2025 22:00:08 -0400 Subject: [PATCH 026/105] dummy tensor.Index compiles - ready to test runnign --- include/tvm/relax/attrs/manipulate.h | 6 ++ .../torch/base_fx_graph_translator.py | 5 ++ .../torch/exported_program_translator.py | 1 + python/tvm/relax/op/__init__.py | 1 + python/tvm/relax/op/manipulate.py | 6 ++ python/tvm/relax/op/op_attrs.py | 7 ++ .../transform/legalize_ops/manipulate.py | 4 + python/tvm/topi/transform.py | 8 ++ src/relax/op/tensor/manipulate.cc | 85 +++++++++++++++++++ src/relax/op/tensor/manipulate.h | 13 +++ 10 files changed, 136 insertions(+) diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h index e6c16d233a6b..82b299f5f1f5 100644 --- a/include/tvm/relax/attrs/manipulate.h +++ b/include/tvm/relax/attrs/manipulate.h @@ -169,6 +169,12 @@ struct GatherNDAttrs : public tvm::AttrsNode { } }; // struct GatherNDAttrs +/*! \brief Attributes used in index_tensor operators */ +struct IndexTensorAttrs : public tvm::AttrsNode { + // TODO is this needed if we just don't have arguments? + TVM_DECLARE_ATTRS(IndexTensorAttrs, "relax.attrs.IndexTensorAttrs") {} +}; // struct IndexTensorAttrs + /*! 
\brief Attributes used in scatter_elements operators */ struct ScatterElementsAttrs : public tvm::AttrsNode { Integer axis; diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index aa8150fac71c..860337747424 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1014,6 +1014,11 @@ def _gather(self, node: fx.Node) -> relax.Var: index = self.env[node.args[2]] return self.block_builder.emit(relax.op.gather_elements(x, index, axis=dim)) + def _index_tensor(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + index = self.env[node.args[1]] + return self.block_builder.emit(relax.op.index_tensor(x, index)) + def _permute(self, node: fx.Node) -> relax.Var: import torch # type: ignore diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index a072ebaf98a7..04be9b6e12f8 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -493,6 +493,7 @@ def from_exported_program( assert ( func_name in self.convert_map ), f"Unsupported function type {func_name}" + print("found function", func_name, "!!!!!!!!!!!!!!!!!!!!!!!!!!!!") self.env[node] = self.convert_map[func_name](node) else: raise ValueError(f"Unsupported op {node.op}") diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 97f18a239640..f81c5448a024 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -94,6 +94,7 @@ flip, gather_elements, gather_nd, + index_tensor, layout_transform, one_hot, permute_dims, diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 0f6e537ab3d6..5fa79eae63a0 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -507,6 +507,12 @@ def gather_nd(data: Expr, indices: Expr, batch_dims: int = 0) -> Expr: """ return _ffi_api.gather_nd(data, indices, batch_dims) # type: ignore +def index_tensor(data: Expr, indices: Expr) -> Expr: + """ + TODO docstring + """ + return _ffi_api.index_tensor(data, indices) # type: ignore + def scatter_elements( data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str = "update" diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 4658950f511a..b80ef3512d4d 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -182,3 +182,10 @@ class EinsumAttrs(Attrs): @tvm._ffi.register_object("relax.attrs.FlipAttrs") class FlipAttrs(Attrs): """Attributes for flip operator""" + + +# TODO is this needed? It looks like not all ops are here +@tvm._ffi.register_object("relax.attrs.IndexTensorAttrs") +class IndexTensorAttrs(Attrs): + """Attributes used in index_tensor operator""" + diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 662d4e946b5f..42ff2ac4f050 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -161,6 +161,10 @@ def te_gather_nd(data, indices, batch_dims): return bb.call_te(te_gather_nd, call.args[0], call.args[1], int(call.attrs.batch_dims)) +# TODO what does this do? 
+@register_legalize("relax.index_tensor") +def _index_tensor(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(topi.index_tensor, call.args[0], call.args[1]) # TODO should I use primfunc_name_hint? @register_legalize("relax.scatter_elements") def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr: diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index b8605aa58a2e..cb7a7ff0c927 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1052,3 +1052,11 @@ def _apply_trilu(*indices): return tvm.tir.Select(check_position, value, tvm.tir.const(0, data.dtype)) return te.compute(data.shape, _apply_trilu, name="trilu", tag=topi.tag.ELEMWISE) + + +def index_tensor(data, indices): + """ TODO docstring """ + # TODO actually implement! + print("we have reached index_tensor in topi !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + return topi.sum(data, axis=[0]) + diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index cb738db363ee..876608a51b01 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -30,6 +30,7 @@ #include #include +#include "tvm/relax/type.h" // kUnknownNDim #include "tvm/runtime/data_type.h" namespace tvm { @@ -474,6 +475,90 @@ TVM_REGISTER_OP("relax.flatten") .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", Bool(true)); +/* relax.index_tensor */ +TVM_REGISTER_NODE_TYPE(IndexTensorAttrs); + +Expr index_tensor(Expr x, Expr indices) { + ObjectPtr attrs = make_object(); + + static const Op& op = Op::Get("relax.index_tensor"); + return Call(op, {std::move(x), std::move(indices)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.index_tensor").set_body_typed(index_tensor); + +// TODO understand every line here? Is this all correct? +StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) { + CheckNumArguments(call, ctx); + TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); + + // StructInfo inference when the index is a PrimValue is equivalent + // to that of a scalar (0-d) tensor. + TensorStructInfo indices_sinfo = [&]() { + auto arg = call->args[1]; + auto sinfo = GetStructInfo(arg); + // TODO update the condition below. The indices argument should always be a tensor, it cannot be + // a scalar value + if (auto tensor_sinfo = sinfo.as()) { + return tensor_sinfo.value(); + } else if (auto prim_sinfo = sinfo.as()) { + return TensorStructInfo(ShapeExpr(Array{}), prim_sinfo->dtype); + } else { + ctx->ReportFatal(Diagnostic::Error(call) + << "Operator " << call->op << " requires the indices argument to be " + << "either a tensor or a scalar value. " + << "However, argument " << arg << " has struct info " << sinfo); + // Unreachable, but [[noreturn]] attribute on virtual function + // `ReportFatal` is insufficient to silence -Wreturn-type, as + // child class might not be [[noreturn]]. + return TensorStructInfo(); + } + }(); + + if (indices_sinfo->IsUnknownDtype()) { + // TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for warning? + LOG(WARNING) << "Data type of indice has not been specified. Assume it has an integer type."; + } else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "Index Tensor op requires the input indices to have integer dtype. 
However, the " + "given indices dtype is " + << indices_sinfo->dtype); + } + + if (data_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim()) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + } + + const auto* attrs = call->attrs.as(); + + const auto* data_shape = data_sinfo->shape.as(); + const auto* indices_shape = indices_sinfo->shape.as(); + if (data_shape == nullptr || indices_shape == nullptr) { + return TensorStructInfo(data_sinfo->dtype, indices_sinfo->ndim + data_sinfo->ndim - 1, + data_sinfo->vdevice); + } + + // TODO can we do better than kUnknownNDim, and instead do something like this for the output + // shape? Array output_shape; for (int i = 0; i < data_sinfo->ndim; i++) { + // if (i == axis) { + // for (int j = 0; j < indices_sinfo->ndim; j++) + // output_shape.push_back(indices_shape->values[j]); + // } else { + // output_shape.push_back(data_shape->values[i]); + // } + // } + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); +} + +TVM_REGISTER_OP("relax.index_tensor") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("x", "Tensor", "The input tensor.") + .add_argument("indices", "Tensor", "The indices of the values to extract.") + .set_attr("FInferStructInfo", InferStructInfoIndexTensor) + .set_attr("FPurity", Bool(true)); + /* relax.layout_transform */ TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs); diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 1a0c7ddbc76c..0c7d482ffacb 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -200,6 +200,19 @@ Expr gather_elements(Expr data, Expr indices, int axis = 0); */ Expr gather_nd(Expr data, Expr indices, int batch_dims = 0); +/*! // TODO update this comment + * \brief Gather values from a tensor using N-dimensional indices. + * \param data The input tensor. + * \param indices The indices tensor, must have integer type. + * \return The computed result. + * + * \note For batch_dims > 0, the first batch_dims dimensions of data and indices must be equal. + * The last dimension of indices indicates the depth of each index vector. + * The output shape is batch_dims + indices.shape[:-1] + data.shape[batch_dims + + * indices.shape[-1]:] + */ +Expr index_tensor(Expr data, Expr indices); + /*! * \brief Scatter updates into an array according to indices. * \param data The input tensor. From 3ed1f1b3454fcfa043813a99c71d77c55ef4da55 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 30 Mar 2025 22:08:46 -0400 Subject: [PATCH 027/105] type error in dummy test --- .../relax/frontend/torch/base_fx_graph_translator.py | 10 +++++++--- .../frontend/torch/exported_program_translator.py | 1 + 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 860337747424..87df43e7cf76 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1015,9 +1015,13 @@ def _gather(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.gather_elements(x, index, axis=dim)) def _index_tensor(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - index = self.env[node.args[1]] - return self.block_builder.emit(relax.op.index_tensor(x, index)) +# ? 
x = self.env[node.args[0]] + # indices = node.args[1] + args = self.retrieve_args(node) + + # index = self.env[node.args[1]] # TODO + return self.block_builder.emit(relax.op.index_tensor(args[0], args[1])) + # return self.block_builder.emit(relax.op.index_tensor(x, indices)) def _permute(self, node: fx.Node) -> relax.Var: import torch # type: ignore diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 04be9b6e12f8..dd7c0d07ca4a 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -371,6 +371,7 @@ def create_convert_map( "expand_as.default": self._expand_as, "flip.default": self._flip, "gather.default": self._gather, + "index.Tensor": self._index_tensor, "permute.default": self._permute, "repeat.default": self._repeat, "select.int": self._select, From 1d1606c083f61895909e69c2a1de7db1c7b4591a Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 31 Mar 2025 00:47:16 -0400 Subject: [PATCH 028/105] failing a check --- include/tvm/relax/attrs/manipulate.h | 6 ++++-- .../relax/frontend/torch/base_fx_graph_translator.py | 10 +++++++++- python/tvm/relax/op/manipulate.py | 7 ++++++- python/tvm/relax/transform/legalize_ops/manipulate.py | 2 +- src/relax/op/tensor/manipulate.cc | 10 +++++----- src/relax/op/tensor/manipulate.h | 2 +- 6 files changed, 26 insertions(+), 11 deletions(-) diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h index 82b299f5f1f5..91fbee8c591d 100644 --- a/include/tvm/relax/attrs/manipulate.h +++ b/include/tvm/relax/attrs/manipulate.h @@ -171,8 +171,10 @@ struct GatherNDAttrs : public tvm::AttrsNode { /*! \brief Attributes used in index_tensor operators */ struct IndexTensorAttrs : public tvm::AttrsNode { - // TODO is this needed if we just don't have arguments? - TVM_DECLARE_ATTRS(IndexTensorAttrs, "relax.attrs.IndexTensorAttrs") {} + Array indices; // TODO will need to extend this, since could be an array of arrays? + TVM_DECLARE_ATTRS(IndexTensorAttrs, "relax.attrs.IndexTensorAttrs") { + TVM_ATTR_FIELD(indices).describe("The indices to select."); + } }; // struct IndexTensorAttrs /*! \brief Attributes used in scatter_elements operators */ diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 87df43e7cf76..3773ae505322 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1018,9 +1018,17 @@ def _index_tensor(self, node: fx.Node) -> relax.Var: # ? x = self.env[node.args[0]] # indices = node.args[1] args = self.retrieve_args(node) + print("len of args", len(args)) + print("type of args[0]", type(args[0])) + print("args[0]", args[0]) + print("type of args[1]", type(args[1])) # Is a list no matter what!!! Like even if we pass a torch.tensor + print("args[1]", args[1]) + + # indices = args[1] # TODO do something like this! 
+ indices = [2,3] # index = self.env[node.args[1]] # TODO - return self.block_builder.emit(relax.op.index_tensor(args[0], args[1])) + return self.block_builder.emit(relax.op.index_tensor(args[0], indices)) # return self.block_builder.emit(relax.op.index_tensor(x, indices)) def _permute(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 5fa79eae63a0..0ffebb97e509 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -507,10 +507,15 @@ def gather_nd(data: Expr, indices: Expr, batch_dims: int = 0) -> Expr: """ return _ffi_api.gather_nd(data, indices, batch_dims) # type: ignore -def index_tensor(data: Expr, indices: Expr) -> Expr: +def index_tensor(data: Expr, indices: List[int]) -> Expr: """ TODO docstring """ + # TODO loosen those assertions! Need to handler lists of lists of lists etc. + assert isinstance(indices, list), "indices should be a list" + assert all(isinstance(i, int) for i in indices), "indices should be a list of integers, but got {}".format( + [type(i) for i in indices] + ) return _ffi_api.index_tensor(data, indices) # type: ignore diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 42ff2ac4f050..8eeb9c026dd4 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -164,7 +164,7 @@ def te_gather_nd(data, indices, batch_dims): # TODO what does this do? @register_legalize("relax.index_tensor") def _index_tensor(bb: BlockBuilder, call: Call) -> Expr: - return bb.call_te(topi.index_tensor, call.args[0], call.args[1]) # TODO should I use primfunc_name_hint? + return bb.call_te(topi.index_tensor, call.args[0], call.attrs.indices) # TODO should I use primfunc_name_hint? 
@register_legalize("relax.scatter_elements") def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr: diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 876608a51b01..f77193fd3c91 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -478,11 +478,12 @@ TVM_REGISTER_OP("relax.flatten") /* relax.index_tensor */ TVM_REGISTER_NODE_TYPE(IndexTensorAttrs); -Expr index_tensor(Expr x, Expr indices) { - ObjectPtr attrs = make_object(); +Expr index_tensor(Expr x, Array indices) { + auto attrs = make_object(); + attrs->indices = std::move(indices); static const Op& op = Op::Get("relax.index_tensor"); - return Call(op, {std::move(x), std::move(indices)}, Attrs(attrs), {}); + return Call(op, {std::move(x)}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relax.op.index_tensor").set_body_typed(index_tensor); @@ -553,9 +554,8 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) TVM_REGISTER_OP("relax.index_tensor") .set_attrs_type() - .set_num_inputs(2) + .set_num_inputs(1) .add_argument("x", "Tensor", "The input tensor.") - .add_argument("indices", "Tensor", "The indices of the values to extract.") .set_attr("FInferStructInfo", InferStructInfoIndexTensor) .set_attr("FPurity", Bool(true)); diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 0c7d482ffacb..8f88d40ccdb4 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -211,7 +211,7 @@ Expr gather_nd(Expr data, Expr indices, int batch_dims = 0); * The output shape is batch_dims + indices.shape[:-1] + data.shape[batch_dims + * indices.shape[-1]:] */ -Expr index_tensor(Expr data, Expr indices); +Expr index_tensor(Expr data, Array indices); /*! * \brief Scatter updates into an array according to indices. From 3ce339ccb6d8c0e2286dda494940e9a51f91f3f3 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 31 Mar 2025 12:24:17 -0400 Subject: [PATCH 029/105] codegen error --- src/relax/op/tensor/manipulate.cc | 87 ++++++++++++++++--------------- 1 file changed, 44 insertions(+), 43 deletions(-) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index f77193fd3c91..6f238107f010 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -493,52 +493,53 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) CheckNumArguments(call, ctx); TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); - // StructInfo inference when the index is a PrimValue is equivalent - // to that of a scalar (0-d) tensor. - TensorStructInfo indices_sinfo = [&]() { - auto arg = call->args[1]; - auto sinfo = GetStructInfo(arg); - // TODO update the condition below. The indices argument should always be a tensor, it cannot be - // a scalar value - if (auto tensor_sinfo = sinfo.as()) { - return tensor_sinfo.value(); - } else if (auto prim_sinfo = sinfo.as()) { - return TensorStructInfo(ShapeExpr(Array{}), prim_sinfo->dtype); - } else { - ctx->ReportFatal(Diagnostic::Error(call) - << "Operator " << call->op << " requires the indices argument to be " - << "either a tensor or a scalar value. " - << "However, argument " << arg << " has struct info " << sinfo); - // Unreachable, but [[noreturn]] attribute on virtual function - // `ReportFatal` is insufficient to silence -Wreturn-type, as - // child class might not be [[noreturn]]. 
- return TensorStructInfo(); - } - }(); - - if (indices_sinfo->IsUnknownDtype()) { - // TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for warning? - LOG(WARNING) << "Data type of indice has not been specified. Assume it has an integer type."; - } else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) { - ctx->ReportFatal( - Diagnostic::Error(call) - << "Index Tensor op requires the input indices to have integer dtype. However, the " - "given indices dtype is " - << indices_sinfo->dtype); - } + // TODO the commented out checks below fail, understand why! + // // StructInfo inference when the index is a PrimValue is equivalent + // // to that of a scalar (0-d) tensor. + // TensorStructInfo indices_sinfo = [&]() { + // auto arg = call->args[0]; // TODO changed this from 1 to 0, is that ok? + // auto sinfo = GetStructInfo(arg); + // // TODO update the condition below. The indices argument should always be a tensor, it cannot be + // // a scalar value + // if (auto tensor_sinfo = sinfo.as()) { + // return tensor_sinfo.value(); + // } else if (auto prim_sinfo = sinfo.as()) { + // return TensorStructInfo(ShapeExpr(Array{}), prim_sinfo->dtype); + // } else { + // ctx->ReportFatal(Diagnostic::Error(call) + // << "Operator " << call->op << " requires the indices argument to be " + // << "either a tensor or a scalar value. " + // << "However, argument " << arg << " has struct info " << sinfo); + // // Unreachable, but [[noreturn]] attribute on virtual function + // // `ReportFatal` is insufficient to silence -Wreturn-type, as + // // child class might not be [[noreturn]]. + // return TensorStructInfo(); + // } + // }(); + + // if (indices_sinfo->IsUnknownDtype()) { + // // TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for warning? + // LOG(WARNING) << "Data type of indice has not been specified. Assume it has an integer type."; + // } else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) { + // ctx->ReportFatal( + // Diagnostic::Error(call) + // << "Index Tensor op requires the input indices to have integer dtype. However, the " + // "given indices dtype is " + // << indices_sinfo->dtype); + // } - if (data_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim()) { - return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); - } + // if (data_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim()) { + // return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + // } - const auto* attrs = call->attrs.as(); + // const auto* attrs = call->attrs.as(); - const auto* data_shape = data_sinfo->shape.as(); - const auto* indices_shape = indices_sinfo->shape.as(); - if (data_shape == nullptr || indices_shape == nullptr) { - return TensorStructInfo(data_sinfo->dtype, indices_sinfo->ndim + data_sinfo->ndim - 1, - data_sinfo->vdevice); - } + // const auto* data_shape = data_sinfo->shape.as(); + // const auto* indices_shape = indices_sinfo->shape.as(); + // if (data_shape == nullptr || indices_shape == nullptr) { + // return TensorStructInfo(data_sinfo->dtype, indices_sinfo->ndim + data_sinfo->ndim - 1, + // data_sinfo->vdevice); + // } // TODO can we do better than kUnknownNDim, and instead do something like this for the output // shape? 
Array output_shape; for (int i = 0; i < data_sinfo->ndim; i++) { From a8c7185f619d1e8021962418b7c48157ab105aa0 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 31 Mar 2025 13:02:25 -0400 Subject: [PATCH 030/105] code gen error --- python/tvm/topi/transform.py | 124 ++++++++++++++++++++++++++++++++++- 1 file changed, 121 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index cb7a7ff0c927..47b57de0263f 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1055,8 +1055,126 @@ def _apply_trilu(*indices): def index_tensor(data, indices): - """ TODO docstring """ - # TODO actually implement! + """ TODO docstring + Replicate data[indices] using only: + - basic indexing on data + - torch.index_select + - concatenation/stack + - broadcasting + … and no advanced indexing. + + Approach for multiple advanced indices: broadcast and loop + + Approach for single advanced index: + 1. Convert the nested Python list to a LongTensor. + 2. Remove exactly one leading dimension of size=1, if present. (Matches PyTorch's shape rule.) + 3. Flatten -> fix negative indices -> index_select -> reshape. + """ + print("we have reached index_tensor in topi !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") - return topi.sum(data, axis=[0]) + return topi.sum(data, axis=[0]) # TODO remove + + # TODO convert everything below in topi-syntax + + def _prod(shape): + """Compute the product of all dimensions in 'shape'.""" + out = 1 + for s in shape: + out *= s + return out + + def _is_multiple_indices(indices): + """ + Decide if 'indices' is multiple parallel indices vs. a single advanced index. + - If the top-level is a list/tuple and len(indices) > 1, interpret it as multiple indices. + - Otherwise, it's a single advanced index. + """ + if isinstance(indices, (list, tuple)): + if len(indices) > 1: + return True + return False + + + # ----------------------------------------------------------- + # CASE B: multiple advanced indices + # ----------------------------------------------------------- + if _is_multiple_indices(indices): + # e.g. 
data[[0,2],[1,3]] => separate indices for dim=0, dim=1 + idx_list = [] + for sub_i in indices: + idx_list.append(torch.tensor(sub_i, dtype=torch.long)) + + # 1) Broadcast them to a common shape B + shapes = [x.shape for x in idx_list] + B = torch.broadcast_shapes(*shapes) + + # 2) Expand each index to that shape + for i in range(len(idx_list)): + idx_list[i] = idx_list[i].expand(B) + + k = len(idx_list) # number of advanced dims + leftover_dims = data.shape[k:] # leftover dims after those k + + M = _prod(B) + # 3) Flatten each index to length=M + for i in range(k): + idx_list[i] = idx_list[i].reshape(M) + + # 4) Enumerate each broadcasted coordinate => basic scalar indexing + slices = [] + for n in range(M): + out_slice = data + for i in range(k): + scalar_idx = idx_list[i][n].item() + # handle negative indexing if you want: + if scalar_idx < 0: + scalar_idx += data.shape[i] + out_slice = out_slice[scalar_idx] + slices.append(out_slice.unsqueeze(0)) # shape [1, leftover_dims] + + # 5) Concatenate -> shape [M, leftover_dims] + stacked = torch.cat(slices, dim=0) + # 6) Reshape -> [B, leftover_dims] + final_shape = list(B) + list(leftover_dims) + result = stacked.view(*final_shape) + return result + + # ----------------------------------------------------------- + # CASE A: single advanced index + # ----------------------------------------------------------- + else: + # 1) Convert the nested Python list -> a LongTensor + # This is allowed. It's not advanced indexing on 'data', + # just building an index tensor from a Python list. + idx_t = torch.tensor(indices, dtype=torch.long) + + # 2) If there's at least one dimension and the first dimension is size=1, + # remove exactly one leading dim. + # (This matches PyTorch's "merge away the top-level [1]" rule for a single advanced index.) + if idx_t.dim() > 0 and idx_t.shape[0] == 1: + idx_t = idx_t.squeeze(0) # remove exactly one leading dim + + # 3) Flatten -> fix negative indices -> index_select + flattened = idx_t.reshape(-1) + # fix negative indices if desired: + # for i in range(flattened.size(0)): + # if flattened[i] < 0: + # flattened[i] = flattened[i] + data.shape[0] + + # we can do this in a vectorized manner if needed: + # but for brevity, let's skip negative idx correction or assume they're in range [0, data.shape[0]-1] + + # picked = torch.index_select(data, dim=0, index=flattened) + picked = take(a=data, indices=flattened, axis=0) + + + # leftover dims + leftover_dims = data.shape[1:] + # final shape = idx_t.shape + leftover_dims + adv_shape = idx_t.shape + final_shape = list(adv_shape) + list(leftover_dims) + result = picked.view(*final_shape) + return result + + From eeefce6b82a97af6bd3bc5481bc75a717eb0e8b1 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 31 Mar 2025 13:47:39 -0400 Subject: [PATCH 031/105] explode broadcast_shapes --- python/tvm/topi/transform.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 47b57de0263f..b8682e28ab8d 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1083,6 +1083,27 @@ def _prod(shape): out *= s return out + def _broadcast_shapes(shapes): + # equivalent to `return torch.broadcast_shapes(*shapes)`, but can't find how to have broadcast_shapes in TVM. 
+ # TODO Try to understand what exported translator does when I pass broadcast_shapes, since cuda_export_index_broadcat_shape + """ + Re-implementation of torch.broadcast_shapes since not sure how to call torch.broadcast_shapes(*shapes) in topi + """ + max_ndim = max(len(s) for s in shapes) + rev_shapes = [s[::-1] for s in shapes] + out = [] + for i in range(max_ndim): + dim_size = 1 + for rsh in rev_shapes: + if i < len(rsh): + s_ = rsh[i] + if s_ != 1 and dim_size != 1 and s_ != dim_size: + raise ValueError(f"Incompatible shapes for broadcast: {shapes}") + dim_size = max(dim_size, s_) + out.append(dim_size) + out.reverse() + return tuple(out) + def _is_multiple_indices(indices): """ Decide if 'indices' is multiple parallel indices vs. a single advanced index. @@ -1106,8 +1127,8 @@ def _is_multiple_indices(indices): # 1) Broadcast them to a common shape B shapes = [x.shape for x in idx_list] - B = torch.broadcast_shapes(*shapes) - + B = _broadcast_shapes(shapes) + # 2) Expand each index to that shape for i in range(len(idx_list)): idx_list[i] = idx_list[i].expand(B) From 3b8a4a0d0e0229945f43897efec185e78fb43f2c Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 31 Mar 2025 14:33:07 -0400 Subject: [PATCH 032/105] debugging --- python/tvm/topi/transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index b8682e28ab8d..1dbfdbc71b0e 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1071,8 +1071,8 @@ def index_tensor(data, indices): 3. Flatten -> fix negative indices -> index_select -> reshape. """ - print("we have reached index_tensor in topi !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") - return topi.sum(data, axis=[0]) # TODO remove + print("we have reached index_tensor in topi !") + return topi.sum(data, axis=[0]) # TODO used for debugging, remove # TODO convert everything below in topi-syntax From 4b990e40b6c24b9ed552c24d8da5415aa249c90d Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 6 Apr 2025 10:52:49 -0400 Subject: [PATCH 033/105] first week of work, index.Tensor branch --- include/tvm/relax/attrs/manipulate.h | 8 + .../torch/base_fx_graph_translator.py | 17 ++ .../torch/exported_program_translator.py | 2 + python/tvm/relax/op/__init__.py | 1 + python/tvm/relax/op/manipulate.py | 11 ++ python/tvm/relax/op/op_attrs.py | 7 + .../transform/legalize_ops/manipulate.py | 4 + python/tvm/topi/transform.py | 147 ++++++++++++++++++ src/relax/op/tensor/manipulate.cc | 86 ++++++++++ src/relax/op/tensor/manipulate.h | 13 ++ 10 files changed, 296 insertions(+) diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h index e6c16d233a6b..91fbee8c591d 100644 --- a/include/tvm/relax/attrs/manipulate.h +++ b/include/tvm/relax/attrs/manipulate.h @@ -169,6 +169,14 @@ struct GatherNDAttrs : public tvm::AttrsNode { } }; // struct GatherNDAttrs +/*! \brief Attributes used in index_tensor operators */ +struct IndexTensorAttrs : public tvm::AttrsNode { + Array indices; // TODO will need to extend this, since could be an array of arrays? + TVM_DECLARE_ATTRS(IndexTensorAttrs, "relax.attrs.IndexTensorAttrs") { + TVM_ATTR_FIELD(indices).describe("The indices to select."); + } +}; // struct IndexTensorAttrs + /*! 
\brief Attributes used in scatter_elements operators */ struct ScatterElementsAttrs : public tvm::AttrsNode { Integer axis; diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index affbd81e1c28..c26dc89f2348 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1058,6 +1058,23 @@ def _gather(self, node: fx.Node) -> relax.Var: index = self.env[node.args[2]] return self.block_builder.emit(relax.op.gather_elements(x, index, axis=dim)) + def _index_tensor(self, node: fx.Node) -> relax.Var: +# ? x = self.env[node.args[0]] + # indices = node.args[1] + args = self.retrieve_args(node) + print("len of args", len(args)) + print("type of args[0]", type(args[0])) + print("args[0]", args[0]) + print("type of args[1]", type(args[1])) # Is a list no matter what!!! Like even if we pass a torch.tensor + print("args[1]", args[1]) + + # indices = args[1] # TODO do something like this! + indices = [2,3] + + # index = self.env[node.args[1]] # TODO + return self.block_builder.emit(relax.op.index_tensor(args[0], indices)) + # return self.block_builder.emit(relax.op.index_tensor(x, indices)) + def _permute(self, node: fx.Node) -> relax.Var: import torch # type: ignore diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 26121ecdea10..6794e4fdb657 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -380,6 +380,7 @@ def create_convert_map( "flatten.using_ints": self._flatten, "flip.default": self._flip, "gather.default": self._gather, + "index.Tensor": self._index_tensor, "permute.default": self._permute, "repeat.default": self._repeat, "select.int": self._select, @@ -505,6 +506,7 @@ def from_exported_program( assert ( func_name in self.convert_map ), f"Unsupported function type {func_name}" + print("found function", func_name, "!!!!!!!!!!!!!!!!!!!!!!!!!!!!") self.env[node] = self.convert_map[func_name](node) else: raise ValueError(f"Unsupported op {node.op}") diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 97f18a239640..f81c5448a024 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -94,6 +94,7 @@ flip, gather_elements, gather_nd, + index_tensor, layout_transform, one_hot, permute_dims, diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 0f6e537ab3d6..0ffebb97e509 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -507,6 +507,17 @@ def gather_nd(data: Expr, indices: Expr, batch_dims: int = 0) -> Expr: """ return _ffi_api.gather_nd(data, indices, batch_dims) # type: ignore +def index_tensor(data: Expr, indices: List[int]) -> Expr: + """ + TODO docstring + """ + # TODO loosen those assertions! Need to handler lists of lists of lists etc. 
+ assert isinstance(indices, list), "indices should be a list" + assert all(isinstance(i, int) for i in indices), "indices should be a list of integers, but got {}".format( + [type(i) for i in indices] + ) + return _ffi_api.index_tensor(data, indices) # type: ignore + def scatter_elements( data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str = "update" diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 4658950f511a..b80ef3512d4d 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -182,3 +182,10 @@ class EinsumAttrs(Attrs): @tvm._ffi.register_object("relax.attrs.FlipAttrs") class FlipAttrs(Attrs): """Attributes for flip operator""" + + +# TODO is this needed? It looks like not all ops are here +@tvm._ffi.register_object("relax.attrs.IndexTensorAttrs") +class IndexTensorAttrs(Attrs): + """Attributes used in index_tensor operator""" + diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 662d4e946b5f..8eeb9c026dd4 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -161,6 +161,10 @@ def te_gather_nd(data, indices, batch_dims): return bb.call_te(te_gather_nd, call.args[0], call.args[1], int(call.attrs.batch_dims)) +# TODO what does this do? +@register_legalize("relax.index_tensor") +def _index_tensor(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(topi.index_tensor, call.args[0], call.attrs.indices) # TODO should I use primfunc_name_hint? @register_legalize("relax.scatter_elements") def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr: diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index b8605aa58a2e..1dbfdbc71b0e 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1052,3 +1052,150 @@ def _apply_trilu(*indices): return tvm.tir.Select(check_position, value, tvm.tir.const(0, data.dtype)) return te.compute(data.shape, _apply_trilu, name="trilu", tag=topi.tag.ELEMWISE) + + +def index_tensor(data, indices): + """ TODO docstring + Replicate data[indices] using only: + - basic indexing on data + - torch.index_select + - concatenation/stack + - broadcasting + … and no advanced indexing. + + Approach for multiple advanced indices: broadcast and loop + + Approach for single advanced index: + 1. Convert the nested Python list to a LongTensor. + 2. Remove exactly one leading dimension of size=1, if present. (Matches PyTorch's shape rule.) + 3. Flatten -> fix negative indices -> index_select -> reshape. + """ + + print("we have reached index_tensor in topi !") + return topi.sum(data, axis=[0]) # TODO used for debugging, remove + + # TODO convert everything below in topi-syntax + + def _prod(shape): + """Compute the product of all dimensions in 'shape'.""" + out = 1 + for s in shape: + out *= s + return out + + def _broadcast_shapes(shapes): + # equivalent to `return torch.broadcast_shapes(*shapes)`, but can't find how to have broadcast_shapes in TVM. 
+ # TODO Try to understand what exported translator does when I pass broadcast_shapes, since cuda_export_index_broadcat_shape + """ + Re-implementation of torch.broadcast_shapes since not sure how to call torch.broadcast_shapes(*shapes) in topi + """ + max_ndim = max(len(s) for s in shapes) + rev_shapes = [s[::-1] for s in shapes] + out = [] + for i in range(max_ndim): + dim_size = 1 + for rsh in rev_shapes: + if i < len(rsh): + s_ = rsh[i] + if s_ != 1 and dim_size != 1 and s_ != dim_size: + raise ValueError(f"Incompatible shapes for broadcast: {shapes}") + dim_size = max(dim_size, s_) + out.append(dim_size) + out.reverse() + return tuple(out) + + def _is_multiple_indices(indices): + """ + Decide if 'indices' is multiple parallel indices vs. a single advanced index. + - If the top-level is a list/tuple and len(indices) > 1, interpret it as multiple indices. + - Otherwise, it's a single advanced index. + """ + if isinstance(indices, (list, tuple)): + if len(indices) > 1: + return True + return False + + + # ----------------------------------------------------------- + # CASE B: multiple advanced indices + # ----------------------------------------------------------- + if _is_multiple_indices(indices): + # e.g. data[[0,2],[1,3]] => separate indices for dim=0, dim=1 + idx_list = [] + for sub_i in indices: + idx_list.append(torch.tensor(sub_i, dtype=torch.long)) + + # 1) Broadcast them to a common shape B + shapes = [x.shape for x in idx_list] + B = _broadcast_shapes(shapes) + + # 2) Expand each index to that shape + for i in range(len(idx_list)): + idx_list[i] = idx_list[i].expand(B) + + k = len(idx_list) # number of advanced dims + leftover_dims = data.shape[k:] # leftover dims after those k + + M = _prod(B) + # 3) Flatten each index to length=M + for i in range(k): + idx_list[i] = idx_list[i].reshape(M) + + # 4) Enumerate each broadcasted coordinate => basic scalar indexing + slices = [] + for n in range(M): + out_slice = data + for i in range(k): + scalar_idx = idx_list[i][n].item() + # handle negative indexing if you want: + if scalar_idx < 0: + scalar_idx += data.shape[i] + out_slice = out_slice[scalar_idx] + slices.append(out_slice.unsqueeze(0)) # shape [1, leftover_dims] + + # 5) Concatenate -> shape [M, leftover_dims] + stacked = torch.cat(slices, dim=0) + # 6) Reshape -> [B, leftover_dims] + final_shape = list(B) + list(leftover_dims) + result = stacked.view(*final_shape) + return result + + # ----------------------------------------------------------- + # CASE A: single advanced index + # ----------------------------------------------------------- + else: + # 1) Convert the nested Python list -> a LongTensor + # This is allowed. It's not advanced indexing on 'data', + # just building an index tensor from a Python list. + idx_t = torch.tensor(indices, dtype=torch.long) + + # 2) If there's at least one dimension and the first dimension is size=1, + # remove exactly one leading dim. + # (This matches PyTorch's "merge away the top-level [1]" rule for a single advanced index.) 
+ if idx_t.dim() > 0 and idx_t.shape[0] == 1: + idx_t = idx_t.squeeze(0) # remove exactly one leading dim + + # 3) Flatten -> fix negative indices -> index_select + flattened = idx_t.reshape(-1) + # fix negative indices if desired: + # for i in range(flattened.size(0)): + # if flattened[i] < 0: + # flattened[i] = flattened[i] + data.shape[0] + + # we can do this in a vectorized manner if needed: + # but for brevity, let's skip negative idx correction or assume they're in range [0, data.shape[0]-1] + + # picked = torch.index_select(data, dim=0, index=flattened) + picked = take(a=data, indices=flattened, axis=0) + + + # leftover dims + leftover_dims = data.shape[1:] + # final shape = idx_t.shape + leftover_dims + adv_shape = idx_t.shape + final_shape = list(adv_shape) + list(leftover_dims) + result = picked.view(*final_shape) + return result + + + diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index cb738db363ee..6f238107f010 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -30,6 +30,7 @@ #include #include +#include "tvm/relax/type.h" // kUnknownNDim #include "tvm/runtime/data_type.h" namespace tvm { @@ -474,6 +475,91 @@ TVM_REGISTER_OP("relax.flatten") .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", Bool(true)); +/* relax.index_tensor */ +TVM_REGISTER_NODE_TYPE(IndexTensorAttrs); + +Expr index_tensor(Expr x, Array indices) { + auto attrs = make_object(); + attrs->indices = std::move(indices); + + static const Op& op = Op::Get("relax.index_tensor"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.index_tensor").set_body_typed(index_tensor); + +// TODO understand every line here? Is this all correct? +StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) { + CheckNumArguments(call, ctx); + TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); + + // TODO the commented out checks below fail, understand why! + // // StructInfo inference when the index is a PrimValue is equivalent + // // to that of a scalar (0-d) tensor. + // TensorStructInfo indices_sinfo = [&]() { + // auto arg = call->args[0]; // TODO changed this from 1 to 0, is that ok? + // auto sinfo = GetStructInfo(arg); + // // TODO update the condition below. The indices argument should always be a tensor, it cannot be + // // a scalar value + // if (auto tensor_sinfo = sinfo.as()) { + // return tensor_sinfo.value(); + // } else if (auto prim_sinfo = sinfo.as()) { + // return TensorStructInfo(ShapeExpr(Array{}), prim_sinfo->dtype); + // } else { + // ctx->ReportFatal(Diagnostic::Error(call) + // << "Operator " << call->op << " requires the indices argument to be " + // << "either a tensor or a scalar value. " + // << "However, argument " << arg << " has struct info " << sinfo); + // // Unreachable, but [[noreturn]] attribute on virtual function + // // `ReportFatal` is insufficient to silence -Wreturn-type, as + // // child class might not be [[noreturn]]. + // return TensorStructInfo(); + // } + // }(); + + // if (indices_sinfo->IsUnknownDtype()) { + // // TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for warning? + // LOG(WARNING) << "Data type of indice has not been specified. 
Assume it has an integer type."; + // } else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) { + // ctx->ReportFatal( + // Diagnostic::Error(call) + // << "Index Tensor op requires the input indices to have integer dtype. However, the " + // "given indices dtype is " + // << indices_sinfo->dtype); + // } + + // if (data_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim()) { + // return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + // } + + // const auto* attrs = call->attrs.as(); + + // const auto* data_shape = data_sinfo->shape.as(); + // const auto* indices_shape = indices_sinfo->shape.as(); + // if (data_shape == nullptr || indices_shape == nullptr) { + // return TensorStructInfo(data_sinfo->dtype, indices_sinfo->ndim + data_sinfo->ndim - 1, + // data_sinfo->vdevice); + // } + + // TODO can we do better than kUnknownNDim, and instead do something like this for the output + // shape? Array output_shape; for (int i = 0; i < data_sinfo->ndim; i++) { + // if (i == axis) { + // for (int j = 0; j < indices_sinfo->ndim; j++) + // output_shape.push_back(indices_shape->values[j]); + // } else { + // output_shape.push_back(data_shape->values[i]); + // } + // } + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); +} + +TVM_REGISTER_OP("relax.index_tensor") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoIndexTensor) + .set_attr("FPurity", Bool(true)); + /* relax.layout_transform */ TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs); diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 1a0c7ddbc76c..8f88d40ccdb4 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -200,6 +200,19 @@ Expr gather_elements(Expr data, Expr indices, int axis = 0); */ Expr gather_nd(Expr data, Expr indices, int batch_dims = 0); +/*! // TODO update this comment + * \brief Gather values from a tensor using N-dimensional indices. + * \param data The input tensor. + * \param indices The indices tensor, must have integer type. + * \return The computed result. + * + * \note For batch_dims > 0, the first batch_dims dimensions of data and indices must be equal. + * The last dimension of indices indicates the depth of each index vector. + * The output shape is batch_dims + indices.shape[:-1] + data.shape[batch_dims + + * indices.shape[-1]:] + */ +Expr index_tensor(Expr data, Array indices); + /*! * \brief Scatter updates into an array according to indices. * \param data The input tensor. From 77dbc3b1a8b7ff7e04ca30fdfcc92c9903248a8f Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 7 Apr 2025 12:06:28 -0400 Subject: [PATCH 034/105] testing --- python/tvm/topi/transform.py | 235 ++++++++++++++++------------------- 1 file changed, 106 insertions(+), 129 deletions(-) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 1dbfdbc71b0e..791d56cbcb44 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1055,7 +1055,11 @@ def _apply_trilu(*indices): def index_tensor(data, indices): - """ TODO docstring + """ TODO docstring + - If 'indices' is a list/tuple of length > 1, we interpret that as multiple advanced indices, + and implement with topi.adv_index (plus negative-index correction if desired). + - Otherwise, interpret 'indices' as a single advanced index, and implement with topi.take. 
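A minimal sketch of the strategy stated in the docstring above, assuming the indices have already been lowered to te.Tensor objects: several parallel indices go through topi.adv_index, a single index goes through topi.take plus a reshape. Negative-index correction, the leading size-1 squeeze, and dynamic shapes are deliberately omitted, so this illustrates the intended lowering rather than the code in this patch.

    from tvm import topi

    def index_tensor_sketch(data, indices):
        """data[indices] built from topi primitives only (illustrative)."""
        if isinstance(indices, (list, tuple)) and len(indices) > 1:
            # one integer tensor per indexed axis; adv_index broadcasts them together
            return topi.adv_index(data, list(indices))
        idx = indices[0] if isinstance(indices, (list, tuple)) else indices
        flat_len = 1
        for s in idx.shape:
            flat_len *= s
        flat = topi.reshape(idx, (flat_len,))    # flatten the advanced index
        picked = topi.take(data, flat, axis=0)   # gather along the first axis of data
        out_shape = tuple(idx.shape) + tuple(data.shape[1:])
        return topi.reshape(picked, out_shape)   # index shape followed by leftover dims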
+ Replicate data[indices] using only: - basic indexing on data - torch.index_select @@ -1070,132 +1074,105 @@ def index_tensor(data, indices): 2. Remove exactly one leading dimension of size=1, if present. (Matches PyTorch's shape rule.) 3. Flatten -> fix negative indices -> index_select -> reshape. """ + # The typical pattern is to define the new output via te.compute, + # with a lambda that describes the element-wise operation. + return te.compute( + data.shape, + lambda *indices: data(*indices) + tvm.tir.const(1, data.dtype), + name="dummy_add_one", + # For a simple element-wise operator, you can use tag=topi.tag.ELEMWISE + tag="elemwise", + ) - print("we have reached index_tensor in topi !") - return topi.sum(data, axis=[0]) # TODO used for debugging, remove - - # TODO convert everything below in topi-syntax - - def _prod(shape): - """Compute the product of all dimensions in 'shape'.""" - out = 1 - for s in shape: - out *= s - return out - - def _broadcast_shapes(shapes): - # equivalent to `return torch.broadcast_shapes(*shapes)`, but can't find how to have broadcast_shapes in TVM. - # TODO Try to understand what exported translator does when I pass broadcast_shapes, since cuda_export_index_broadcat_shape - """ - Re-implementation of torch.broadcast_shapes since not sure how to call torch.broadcast_shapes(*shapes) in topi - """ - max_ndim = max(len(s) for s in shapes) - rev_shapes = [s[::-1] for s in shapes] - out = [] - for i in range(max_ndim): - dim_size = 1 - for rsh in rev_shapes: - if i < len(rsh): - s_ = rsh[i] - if s_ != 1 and dim_size != 1 and s_ != dim_size: - raise ValueError(f"Incompatible shapes for broadcast: {shapes}") - dim_size = max(dim_size, s_) - out.append(dim_size) - out.reverse() - return tuple(out) - - def _is_multiple_indices(indices): - """ - Decide if 'indices' is multiple parallel indices vs. a single advanced index. - - If the top-level is a list/tuple and len(indices) > 1, interpret it as multiple indices. - - Otherwise, it's a single advanced index. - """ - if isinstance(indices, (list, tuple)): - if len(indices) > 1: - return True - return False - - - # ----------------------------------------------------------- - # CASE B: multiple advanced indices - # ----------------------------------------------------------- - if _is_multiple_indices(indices): - # e.g. 
data[[0,2],[1,3]] => separate indices for dim=0, dim=1 - idx_list = [] - for sub_i in indices: - idx_list.append(torch.tensor(sub_i, dtype=torch.long)) - - # 1) Broadcast them to a common shape B - shapes = [x.shape for x in idx_list] - B = _broadcast_shapes(shapes) - - # 2) Expand each index to that shape - for i in range(len(idx_list)): - idx_list[i] = idx_list[i].expand(B) - - k = len(idx_list) # number of advanced dims - leftover_dims = data.shape[k:] # leftover dims after those k - - M = _prod(B) - # 3) Flatten each index to length=M - for i in range(k): - idx_list[i] = idx_list[i].reshape(M) - - # 4) Enumerate each broadcasted coordinate => basic scalar indexing - slices = [] - for n in range(M): - out_slice = data - for i in range(k): - scalar_idx = idx_list[i][n].item() - # handle negative indexing if you want: - if scalar_idx < 0: - scalar_idx += data.shape[i] - out_slice = out_slice[scalar_idx] - slices.append(out_slice.unsqueeze(0)) # shape [1, leftover_dims] - - # 5) Concatenate -> shape [M, leftover_dims] - stacked = torch.cat(slices, dim=0) - # 6) Reshape -> [B, leftover_dims] - final_shape = list(B) + list(leftover_dims) - result = stacked.view(*final_shape) - return result - - # ----------------------------------------------------------- - # CASE A: single advanced index - # ----------------------------------------------------------- - else: - # 1) Convert the nested Python list -> a LongTensor - # This is allowed. It's not advanced indexing on 'data', - # just building an index tensor from a Python list. - idx_t = torch.tensor(indices, dtype=torch.long) - - # 2) If there's at least one dimension and the first dimension is size=1, - # remove exactly one leading dim. - # (This matches PyTorch's "merge away the top-level [1]" rule for a single advanced index.) - if idx_t.dim() > 0 and idx_t.shape[0] == 1: - idx_t = idx_t.squeeze(0) # remove exactly one leading dim - - # 3) Flatten -> fix negative indices -> index_select - flattened = idx_t.reshape(-1) - # fix negative indices if desired: - # for i in range(flattened.size(0)): - # if flattened[i] < 0: - # flattened[i] = flattened[i] + data.shape[0] - - # we can do this in a vectorized manner if needed: - # but for brevity, let's skip negative idx correction or assume they're in range [0, data.shape[0]-1] - - # picked = torch.index_select(data, dim=0, index=flattened) - picked = take(a=data, indices=flattened, axis=0) - - - # leftover dims - leftover_dims = data.shape[1:] - # final shape = idx_t.shape + leftover_dims - adv_shape = idx_t.shape - final_shape = list(adv_shape) + list(leftover_dims) - result = picked.view(*final_shape) - return result - - - + # return data + + # TODO uncomment + + + + # # Helper to fix negative indices: out_idx = where(idx<0, idx+dim_size, idx) + # def _fix_negatives(idx_t, dim_size): + # # idx_t, dim_size are tvm.te.Tensor or integers. + # # We'll broadcast if needed. We can do so by calling topi.where(...) with the condition + # # (idx_t < 0). + # # For static shape, `dim_size` could be int. For dynamic shape, dim_size might be a Tensor. + # # Suppose dim_size is int here. Then we can just do: + + # # TODO uncomment + # # zero_t = topi.full_like(idx_t, 0) + # # dim_size_t = topi.full_like(idx_t, dim_size) # broadcast if needed + # # return topi.where(topi.less(idx_t, zero_t), topi.add(idx_t, dim_size_t), idx_t) + # return topi.where(idx_t < 0, idx_t + dim_size, idx_t) + + # # --- Check whether indices is multiple advanced indices or single advanced index. 
--- + # if isinstance(indices, (list, tuple)) and len(indices) > 1: + # # ----------------------------------------------------------- + # # CASE B: multiple advanced indices + # # ----------------------------------------------------------- + # # Suppose each sub_i is a tvm.te.Tensor of integer type, indexing a separate dimension. + # # We want to broadcast them to a common shape (if not already), + # # fix negative indices, then use topi.adv_index. + # idx_list = list(indices) + + # # 1) Determine broadcast shape. For simplicity we can rely on `topi.adv_index` automatically + # # broadcasting the indices if they differ in shape. If you need explicit broadcasting, + # # you can do so via topi utilities (e.g. topi.broadcast_to). + # # Then fix negative indexing dimensionwise. + # # data.shape is e.g. [d0, d1, d2, ...], so for the i-th advanced index, dimension = data.shape[i]. + # # We fix negative indexing if desired: + # final_indices = [] + # for i, idx_t in enumerate(idx_list): + # # If we want negative fix, do it here: + # dim_size = data.shape[i] # a PrimExpr + # fixed = _fix_negatives(idx_t, dim_size) + # final_indices.append(fixed) + + # # 2) Use topi.adv_index + # # This will produce a new tensor with shape = broadcast of final_indices. + # result = topi.adv_index(data, final_indices) + # return result + + # else: + # # ----------------------------------------------------------- + # # CASE A: single advanced index + # # ----------------------------------------------------------- + # # We interpret 'indices' as a single integer-tensor for dimension=0 indexing. + # # So the result shape is [*indices_shape, leftover_dims], with leftover_dims = data.shape[1:]. + # # + # # Steps, paralleling the Python: + # # 1) If the first dimension of indices is 1, remove it => topi.squeeze if we want. 
+ # # 2) Flatten => topi.reshape + # # 3) fix negative indices => topi.where + # # 4) gather => topi.take(..., axis=0) + # # 5) reshape => combine advanced dims + leftover dims + # idx_t = indices if isinstance(indices, te.Tensor) else indices[0] + + # # Possibly remove leading dimension if shape[0]==1: + # if len(idx_t.shape) > 0: + # first_dim = idx_t.shape[0] + # if isinstance(first_dim, int) and first_dim == 1: + # # topi.squeeze can remove exactly one axis: + # idx_t = topi.squeeze(idx_t, axis=[0]) + # else: + # # If we suspect it's dynamic, we can check with a small schedule or approach, + # # but here's the naive approach: we skip if the dimension is unknown + # pass + + # # Flatten + # flattened = topi.reshape(idx_t, (-1,)) + + # # fix negative indexing + # # data.shape[0] is batch dimension + # fixed = _fix_negatives(flattened, data.shape[0]) + + # # gather => topi.take + # # out shape = [len_of_fixed] + leftover_dims + # picked = topi.take(data, fixed, axis=0) + + # # final reshape => idx_t original shape (after squeeze) + leftover + # # we can get idx_t's shape with topi.shape if dynamic, or known statically + # adv_shape = tuple(idx_t.shape) # or topi.shape(idx_t) if dynamic + # leftover_dims = data.shape[1:] + # final_shape = adv_shape + leftover_dims + # result = topi.reshape(picked, final_shape) + # return result From 557885be6127cbc895f0dae9ea6866f83928d874 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 16:16:32 -0400 Subject: [PATCH 035/105] able to get an output with dummy relax.op.collapse_sum_like --- .../torch/base_fx_graph_translator.py | 22 +++++++++++++++++-- python/tvm/relax/op/manipulate.py | 12 +++++----- python/tvm/topi/broadcast.py | 3 ++- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index c26dc89f2348..f51ccd76092a 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -93,12 +93,16 @@ def _retrieve_args(self, node): from torch import fx if isinstance(node, fx.Node): + print("isinstance(node, fx.Node)") return self.env[node] elif isinstance(node, tuple): + print("isinstance(node, tuple) of length", len(node)) return tuple(self._retrieve_args(x) for x in node) elif isinstance(node, list): + print("isinstance(node, list) of length", len(node)) return [self._retrieve_args(x) for x in node] elif isinstance(node, dict): + print("isinstance(node, dict)") return {self._retrieve_args(k): self._retrieve_args(v) for k, v in node.items()} else: return node @@ -1059,9 +1063,12 @@ def _gather(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.gather_elements(x, index, axis=dim)) def _index_tensor(self, node: fx.Node) -> relax.Var: + # TODO should I be using _binary_op() ? + # ? x = self.env[node.args[0]] # indices = node.args[1] args = self.retrieve_args(node) + print("node: fx.Node passed to index_tensor:") print("len of args", len(args)) print("type of args[0]", type(args[0])) print("args[0]", args[0]) @@ -1069,10 +1076,21 @@ def _index_tensor(self, node: fx.Node) -> relax.Var: print("args[1]", args[1]) # indices = args[1] # TODO do something like this! 
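Stripped of the debugging prints, the converter this function is working toward (per the commented-out relax.op.index_tensor call and the later patches in this series) is roughly the sketch below. It is a drop-in body for the translator method, relying on the class's retrieve_args and block_builder helpers, and it assumes, as the prints confirm, that args[1] always arrives as a list of relax expressions; only the first index tensor is forwarded because the Relax op at this stage accepts a single indices tensor.

    def _index_tensor(self, node: fx.Node) -> relax.Var:
        args = self.retrieve_args(node)
        data = args[0]
        # args[1] is always a list; only one advanced index is supported so far
        index = args[1][0]
        return self.block_builder.emit(relax.op.index_tensor(data, index))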
- indices = [2,3] + # indices = [2,3] + indices = args[1][0] + print("type of indices", type(indices)) + # print("indices:") + # args_indices = self.retrieve_args(indices) + # print("len of args_indices", len(args_indices)) + # print("type of args_indices[0]", type(args_indices[0])) + # print("args_indices[0]", args_indices[0]) + # print("type of args_indices[1]", type(args_indices[1])) # Is a list no matter what!!! Like even if we pass a torch.tensor + # print("args_indices[1]", args_indices[1]) + # index = self.env[node.args[1]] # TODO - return self.block_builder.emit(relax.op.index_tensor(args[0], indices)) + # return self.block_builder.emit(relax.op.index_tensor(args[0], indices)) # TODO revert! removed to test collapse sum like + return self.block_builder.emit(relax.op.collapse_sum_like(args[0], indices)) # return self.block_builder.emit(relax.op.index_tensor(x, indices)) def _permute(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 0ffebb97e509..91611c7b6a19 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -507,16 +507,16 @@ def gather_nd(data: Expr, indices: Expr, batch_dims: int = 0) -> Expr: """ return _ffi_api.gather_nd(data, indices, batch_dims) # type: ignore -def index_tensor(data: Expr, indices: List[int]) -> Expr: +def index_tensor(data: Expr, indices: Expr) -> Expr: """ TODO docstring """ # TODO loosen those assertions! Need to handler lists of lists of lists etc. - assert isinstance(indices, list), "indices should be a list" - assert all(isinstance(i, int) for i in indices), "indices should be a list of integers, but got {}".format( - [type(i) for i in indices] - ) - return _ffi_api.index_tensor(data, indices) # type: ignore + # assert isinstance(indices, list), f"indices should be a list, but is a {type(indices)}. Data is a {type(data)}" + # assert all(isinstance(i, int) for i in indices), "indices should be a list of integers, but got {}".format( + # [type(i) for i in indices] + # ) + return _ffi_api.add(data, indices) # type: ignore def scatter_elements( diff --git a/python/tvm/topi/broadcast.py b/python/tvm/topi/broadcast.py index 2b350ff817d9..597c7f24d4c2 100644 --- a/python/tvm/topi/broadcast.py +++ b/python/tvm/topi/broadcast.py @@ -56,7 +56,8 @@ def add(lhs, rhs): Returns Expr if both operands are Expr. Otherwise returns Tensor. 
""" - return _cpp.add(lhs, rhs) + return lhs + # return _cpp.add(lhs, rhs) # TODO revert def subtract(lhs, rhs): From 6b664cc7d3dc1162b3d1039e021d83f643c28d3e Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 16:38:15 -0400 Subject: [PATCH 036/105] whether I can use collapse sum like TWO dpeends on whether I legalize in legalize_ops/manipulate.py --- .../torch/base_fx_graph_translator.py | 2 +- python/tvm/relax/op/__init__.py | 1 + python/tvm/relax/op/manipulate.py | 20 ++++++++ .../transform/legalize_ops/manipulate.py | 7 ++- python/tvm/script/ir_builder/relax/ir.py | 4 +- src/relax/op/tensor/manipulate.cc | 46 +++++++++++++++++++ src/relax/op/tensor/manipulate.h | 9 ++++ 7 files changed, 86 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index f51ccd76092a..c3c369486a4d 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1090,7 +1090,7 @@ def _index_tensor(self, node: fx.Node) -> relax.Var: # index = self.env[node.args[1]] # TODO # return self.block_builder.emit(relax.op.index_tensor(args[0], indices)) # TODO revert! removed to test collapse sum like - return self.block_builder.emit(relax.op.collapse_sum_like(args[0], indices)) + return self.block_builder.emit(relax.op.collapse_sum_like_TWO(args[0], indices)) # return self.block_builder.emit(relax.op.index_tensor(x, indices)) def _permute(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index f81c5448a024..19794340a9c0 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -87,6 +87,7 @@ from .manipulate import ( broadcast_to, collapse_sum_like, + collapse_sum_like_TWO, collapse_sum_to, concat, expand_dims, diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 91611c7b6a19..5e3d21623146 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -300,6 +300,26 @@ def collapse_sum_like(data: Expr, collapse_target: Expr) -> Expr: return _ffi_api.collapse_sum_like(data, collapse_target) # type: ignore +def collapse_sum_like_TWO(data: Expr, collapse_target: Expr) -> Expr: + """Return a summation of data to the shape of collapse_target. + + For details, please see relax.op.collapse_sum_to. + + Parameters + ---------- + data : relax.Expr + The input tensor. + + collapse_target : relax.Expr + The tensor whose shape is the shape to collapse to. + + Returns + ------- + result : relax.Expr + The result tensor after summation. + """ + return _ffi_api.collapse_sum_like_TWO(data, collapse_target) # type: ignore + def collapse_sum_to(data: Expr, shape: Union[Tuple[PrimExprLike], Expr]) -> Expr: """Return a summation of data to the given shape. diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 8eeb9c026dd4..371f4b1f7bb3 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -45,10 +45,15 @@ def reshape_call_te(bb: BlockBuilder, call: Call): register_legalize("relax.broadcast_to", _reshape(topi.broadcast_to, "broadcast_to")) register_legalize("relax.reshape", _reshape(topi.reshape, "reshape")) -register_legalize( +register_legalize( # TODO how about _TWO ???? 
"relax.collapse_sum_like", _reshape(topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True), ) +# register_legalize( # TODO how about _TWO ???? +# "relax.collapse_sum_like_TWO", +# _reshape(topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True), +# ) + register_legalize("relax.collapse_sum_to", _reshape(topi.collapse_sum, "collapse_sum")) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index ddc534cf6086..afce5c86750a 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -70,6 +70,7 @@ ceil, clip, collapse_sum_like, + collapse_sum_like_TWO, # TODO is this necessary? collapse_sum_to, concat, cos, @@ -734,7 +735,8 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "call_builtin_with_ctx", "ceil", "clip", - "collapse_sum_like", + collapse_sum_like, + "collapse_sum_like_TWO", # TODO is this necessary? "collapse_sum_to", "concat", "cos", diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 6f238107f010..df246c9eab2e 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -1325,6 +1325,52 @@ TVM_REGISTER_OP("relax.collapse_sum_like") .set_attr("FInferStructInfo", InferStructInfoCollapseSumLike) .set_attr("FPurity", Bool(true)); +/* relax.collapse_sum_like_TWO */ +Expr collapse_sum_like_TWO(Expr data, Expr collapse_target) { + static const Op& op = Op::Get("relax.collapse_sum_like_TWO"); + return Call(op, {std::move(data), std::move(collapse_target)}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.collapse_sum_like_TWO").set_body_typed(collapse_sum_like_TWO); + +StructInfo InferStructInfoCollapseSumLikeTWO(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo data_sinfo = input_sinfo[0]; + TensorStructInfo collapse_target_sinfo = input_sinfo[1]; + + DataType output_dtype = data_sinfo->dtype; + + Optional> data_shape_value; + if (data_sinfo->shape.defined()) { + data_shape_value = GetStructInfoAs(data_sinfo->shape.value())->values; + } + Optional> collapse_target_shape_value; + if (collapse_target_sinfo->shape.defined()) { + collapse_target_shape_value = + GetStructInfoAs(collapse_target_sinfo->shape.value())->values; + } + + if (data_shape_value.defined() && collapse_target_shape_value.defined()) { + CheckCollapseShape(call, ctx, data_shape_value.value(), collapse_target_shape_value.value()); + } + + if (collapse_target_sinfo->shape.defined()) { + return TensorStructInfo(collapse_target_sinfo->shape.value(), output_dtype, + collapse_target_sinfo->vdevice); + } else { + return TensorStructInfo(output_dtype, collapse_target_sinfo->ndim, + collapse_target_sinfo->vdevice); + } +} + +TVM_REGISTER_OP("relax.collapse_sum_like_TWO") + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("collapse_target", "Tensor", + "The tensor whose shape is the shape to collapse to.") + .set_attr("FInferStructInfo", InferStructInfoCollapseSumLikeTWO) + .set_attr("FPurity", Bool(true)); + /* relax.collapse_sum_to */ Expr collapse_sum_to(Expr data, Expr shape) { static const Op& op = Op::Get("relax.collapse_sum_to"); diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 8f88d40ccdb4..f22719faab4c 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -127,6 +127,15 @@ Expr squeeze(Expr x, Optional> axis); */ Expr collapse_sum_like(Expr data, Expr collapse_target); 
+/*! + * \brief Return a summation of data to the shape of collapse_target. + * For details, please see the operator `relax.collapse_sum_to`. + * \param data The input tensor. + * \param collapse_target The tensor whose shape is the shape to collapse to. + * \return The result tensor after summation. + */ +Expr collapse_sum_like_TWO(Expr data, Expr collapse_target); + /*! * \brief Return a summation of data to the given shape. * collapse_sum_to is intended as the backward operator of broadcast_to and From a4116621882bac510f16a5d3a58930ca118b5685 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 16:43:02 -0400 Subject: [PATCH 037/105] able to call relax.collapse_sum_like_TWO --- python/tvm/relax/transform/legalize_ops/manipulate.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 371f4b1f7bb3..c934af8ed416 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -45,14 +45,14 @@ def reshape_call_te(bb: BlockBuilder, call: Call): register_legalize("relax.broadcast_to", _reshape(topi.broadcast_to, "broadcast_to")) register_legalize("relax.reshape", _reshape(topi.reshape, "reshape")) -register_legalize( # TODO how about _TWO ???? +register_legalize( "relax.collapse_sum_like", _reshape(topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True), ) -# register_legalize( # TODO how about _TWO ???? -# "relax.collapse_sum_like_TWO", -# _reshape(topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True), -# ) +register_legalize( # TODO try to call a topi directly? + "relax.collapse_sum_like_TWO", + _reshape(topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True), +) register_legalize("relax.collapse_sum_to", _reshape(topi.collapse_sum, "collapse_sum")) From c418fc54d8444c7e3c2f418780c1bd8724795928 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 16:46:52 -0400 Subject: [PATCH 038/105] _TWO works with regualar registration --- .../tvm/relax/transform/legalize_ops/manipulate.py | 14 ++++++++++---- python/tvm/topi/transform.py | 3 +++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index c934af8ed416..1e9257b14ba5 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -49,10 +49,16 @@ def reshape_call_te(bb: BlockBuilder, call: Call): "relax.collapse_sum_like", _reshape(topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True), ) -register_legalize( # TODO try to call a topi directly? - "relax.collapse_sum_like_TWO", - _reshape(topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True), -) + +# register_legalize( # TODO try to call a topi directly? +# "relax.collapse_sum_like_TWO", +# _reshape(topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True), +# ) + +@register_legalize("relax.collapse_sum_like_TWO") +def _index_tensor(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(topi.index_tensor, call.args[0], call.args[1]) # TODO should I use primfunc_name_hint? 
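The collapse_sum_like_TWO experiments above come down to one point: a new Relax op only produces compiled output once a legalization maps it onto a TE computation. A minimal sketch of that pattern for the op this series is actually adding, assuming the topi.index_tensor helper introduced earlier in the series (it is not an established TOPI API) and using primfunc_name_hint, which also addresses the TODO about naming the generated PrimFunc:

    from tvm import topi
    from tvm.relax import BlockBuilder, Call, Expr
    from tvm.relax.transform.legalize_ops.common import register_legalize

    @register_legalize("relax.index_tensor")
    def _index_tensor(bb: BlockBuilder, call: Call) -> Expr:
        # lower the Relax call to a TE compute; the emitted PrimFunc gets a readable name
        return bb.call_te(
            topi.index_tensor, call.args[0], call.args[1], primfunc_name_hint="index_tensor"
        )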
+ register_legalize("relax.collapse_sum_to", _reshape(topi.collapse_sum, "collapse_sum")) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 1dbfdbc71b0e..d191385294fb 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1054,6 +1054,9 @@ def _apply_trilu(*indices): return te.compute(data.shape, _apply_trilu, name="trilu", tag=topi.tag.ELEMWISE) +def collapse_sum_like_TWO(data, indices): + return topi.sum(data, axis=[0]) + def index_tensor(data, indices): """ TODO docstring Replicate data[indices] using only: From c497eafff3d3373c6436a47607c95126e2e4b1b0 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 16:49:57 -0400 Subject: [PATCH 039/105] both topi options work --- python/tvm/topi/transform.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index d191385294fb..c18e6cf276f2 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1055,7 +1055,9 @@ def _apply_trilu(*indices): def collapse_sum_like_TWO(data, indices): - return topi.sum(data, axis=[0]) + return data + # return indices # both work! + # return topi.sum(data, axis=[0]) # both work! def index_tensor(data, indices): """ TODO docstring From 7438ce59a3265706950751bf2566f4c02ae6283f Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 16:56:36 -0400 Subject: [PATCH 040/105] can still output from _TWO after merge index.Tensor and index.Tensor3 --- .../frontend/torch/base_fx_graph_translator.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index a412be56417f..0376cc9afc29 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1075,18 +1075,12 @@ def _gather(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.gather_elements(x, index, axis=dim)) def _index_tensor(self, node: fx.Node) -> relax.Var: -<<<<<<< HEAD # TODO should I be using _binary_op() ? # ? x = self.env[node.args[0]] # indices = node.args[1] args = self.retrieve_args(node) print("node: fx.Node passed to index_tensor:") -======= -# ? x = self.env[node.args[0]] - # indices = node.args[1] - args = self.retrieve_args(node) ->>>>>>> index.Tensor3 print("len of args", len(args)) print("type of args[0]", type(args[0])) print("args[0]", args[0]) @@ -1094,7 +1088,6 @@ def _index_tensor(self, node: fx.Node) -> relax.Var: print("args[1]", args[1]) # indices = args[1] # TODO do something like this! -<<<<<<< HEAD # indices = [2,3] indices = args[1][0] print("type of indices", type(indices)) @@ -1110,13 +1103,7 @@ def _index_tensor(self, node: fx.Node) -> relax.Var: # index = self.env[node.args[1]] # TODO # return self.block_builder.emit(relax.op.index_tensor(args[0], indices)) # TODO revert! 
removed to test collapse sum like return self.block_builder.emit(relax.op.collapse_sum_like_TWO(args[0], indices)) -======= - indices = [2,3] - - # index = self.env[node.args[1]] # TODO - return self.block_builder.emit(relax.op.index_tensor(args[0], indices)) ->>>>>>> index.Tensor3 - # return self.block_builder.emit(relax.op.index_tensor(x, indices)) + # return self.block_builder.emit(relax.op.index_tensor(args[0], indices)) # TODO switch the above to this def _permute(self, node: fx.Node) -> relax.Var: import torch # type: ignore From 9d6270bee6d8876db1fec4dd0019d7351c49e066 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 17:02:21 -0400 Subject: [PATCH 041/105] still able to output after building (hadn't built after merge) --- python/tvm/topi/transform.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 4d8a41fd9ad6..b093be86aa0b 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1055,6 +1055,9 @@ def _apply_trilu(*indices): def collapse_sum_like_TWO(data, indices): + print("IN TOPI: !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + print(data) + print(indices) return data # return indices # both work! # return topi.sum(data, axis=[0]) # both work! From 29fbc0185fd77a61b83b06cf732ba7834a3cec8e Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 17:08:56 -0400 Subject: [PATCH 042/105] must do a topi op in transform.py to get an output --- .../tvm/relax/transform/legalize_ops/manipulate.py | 13 +++++++------ python/tvm/topi/transform.py | 5 ++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 1e9257b14ba5..af4840104850 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -50,15 +50,16 @@ def reshape_call_te(bb: BlockBuilder, call: Call): _reshape(topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True), ) -# register_legalize( # TODO try to call a topi directly? -# "relax.collapse_sum_like_TWO", -# _reshape(topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True), -# ) +# # TODO this correctly calls index_tensor! +# @register_legalize("relax.collapse_sum_like_TWO") +# def _index_tensor(bb: BlockBuilder, call: Call) -> Expr: +# return bb.call_te(topi.index_tensor, call.args[0], call.args[1]) # TODO should I use primfunc_name_hint? + +# TODO this correctly calls index_tensor! @register_legalize("relax.collapse_sum_like_TWO") def _index_tensor(bb: BlockBuilder, call: Call) -> Expr: - return bb.call_te(topi.index_tensor, call.args[0], call.args[1]) # TODO should I use primfunc_name_hint? - + return bb.call_te(topi.collapse_sum_like_TWO, call.args[0], call.args[1]) # TODO should I use primfunc_name_hint? register_legalize("relax.collapse_sum_to", _reshape(topi.collapse_sum, "collapse_sum")) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index b093be86aa0b..a838e980b209 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1058,9 +1058,8 @@ def collapse_sum_like_TWO(data, indices): print("IN TOPI: !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") print(data) print(indices) - return data - # return indices # both work! - # return topi.sum(data, axis=[0]) # both work! 
+ # return data # doesn't work + return topi.sum(data, axis=[0]) def index_tensor(data, indices): """ TODO docstring From ff88c5ab240a4b3c3320096825cfef3ced40ca4c Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 17:34:58 -0400 Subject: [PATCH 043/105] error re. IndexTensorAttrs --- .../torch/base_fx_graph_translator.py | 2 +- python/tvm/relax/op/__init__.py | 1 - python/tvm/relax/op/manipulate.py | 22 --- .../transform/legalize_ops/manipulate.py | 14 +- python/tvm/script/ir_builder/relax/ir.py | 6 +- python/tvm/topi/transform.py | 24 ++- src/relax/op/tensor/manipulate.cc | 146 ++++-------------- src/relax/op/tensor/manipulate.h | 2 +- 8 files changed, 43 insertions(+), 174 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 0376cc9afc29..7cc390d5e67b 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1102,7 +1102,7 @@ def _index_tensor(self, node: fx.Node) -> relax.Var: # index = self.env[node.args[1]] # TODO # return self.block_builder.emit(relax.op.index_tensor(args[0], indices)) # TODO revert! removed to test collapse sum like - return self.block_builder.emit(relax.op.collapse_sum_like_TWO(args[0], indices)) + return self.block_builder.emit(relax.op.index_tensor(args[0], indices)) # return self.block_builder.emit(relax.op.index_tensor(args[0], indices)) # TODO switch the above to this def _permute(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 19794340a9c0..f81c5448a024 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -87,7 +87,6 @@ from .manipulate import ( broadcast_to, collapse_sum_like, - collapse_sum_like_TWO, collapse_sum_to, concat, expand_dims, diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index a0997cc2085f..159e496fa16c 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -299,27 +299,6 @@ def collapse_sum_like(data: Expr, collapse_target: Expr) -> Expr: """ return _ffi_api.collapse_sum_like(data, collapse_target) # type: ignore - -def collapse_sum_like_TWO(data: Expr, collapse_target: Expr) -> Expr: - """Return a summation of data to the shape of collapse_target. - - For details, please see relax.op.collapse_sum_to. - - Parameters - ---------- - data : relax.Expr - The input tensor. - - collapse_target : relax.Expr - The tensor whose shape is the shape to collapse to. - - Returns - ------- - result : relax.Expr - The result tensor after summation. - """ - return _ffi_api.collapse_sum_like_TWO(data, collapse_target) # type: ignore - def collapse_sum_to(data: Expr, shape: Union[Tuple[PrimExprLike], Expr]) -> Expr: """Return a summation of data to the given shape. 
@@ -538,7 +517,6 @@ def index_tensor(data: Expr, indices: Expr) -> Expr: # ) return _ffi_api.index_tensor(data, indices) # type: ignore - def scatter_elements( data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str = "update" ): diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index af4840104850..d4aac7abc392 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -50,17 +50,6 @@ def reshape_call_te(bb: BlockBuilder, call: Call): _reshape(topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True), ) -# # TODO this correctly calls index_tensor! -# @register_legalize("relax.collapse_sum_like_TWO") -# def _index_tensor(bb: BlockBuilder, call: Call) -> Expr: -# return bb.call_te(topi.index_tensor, call.args[0], call.args[1]) # TODO should I use primfunc_name_hint? - - -# TODO this correctly calls index_tensor! -@register_legalize("relax.collapse_sum_like_TWO") -def _index_tensor(bb: BlockBuilder, call: Call) -> Expr: - return bb.call_te(topi.collapse_sum_like_TWO, call.args[0], call.args[1]) # TODO should I use primfunc_name_hint? - register_legalize("relax.collapse_sum_to", _reshape(topi.collapse_sum, "collapse_sum")) @@ -173,10 +162,9 @@ def te_gather_nd(data, indices, batch_dims): return bb.call_te(te_gather_nd, call.args[0], call.args[1], int(call.attrs.batch_dims)) -# TODO what does this do? @register_legalize("relax.index_tensor") def _index_tensor(bb: BlockBuilder, call: Call) -> Expr: - return bb.call_te(topi.index_tensor, call.args[0], call.attrs.indices) # TODO should I use primfunc_name_hint? + return bb.call_te(topi.index_tensor, call.args[0], call.args[1]) # TODO should I use primfunc_name_hint? @register_legalize("relax.scatter_elements") def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr: diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index afce5c86750a..e0ead3098d5a 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -70,7 +70,6 @@ ceil, clip, collapse_sum_like, - collapse_sum_like_TWO, # TODO is this necessary? collapse_sum_to, concat, cos, @@ -102,6 +101,7 @@ greater_equal, hint_on_device, image, + index_tensor, # TODO do something with this or remove? invoke_closure, invoke_pure_closure, isfinite, @@ -735,8 +735,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "call_builtin_with_ctx", "ceil", "clip", - collapse_sum_like, - "collapse_sum_like_TWO", # TODO is this necessary? + "collapse_sum_like", "collapse_sum_to", "concat", "cos", @@ -785,6 +784,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "hexagon", "hint_on_device", "image", + "index_tensor", # TODO keep or remove? 
"invoke_closure", "invoke_pure_closure", "isfinite", diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index a838e980b209..02f03cc2cfab 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1053,14 +1053,6 @@ def _apply_trilu(*indices): return te.compute(data.shape, _apply_trilu, name="trilu", tag=topi.tag.ELEMWISE) - -def collapse_sum_like_TWO(data, indices): - print("IN TOPI: !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") - print(data) - print(indices) - # return data # doesn't work - return topi.sum(data, axis=[0]) - def index_tensor(data, indices): """ TODO docstring - If 'indices' is a list/tuple of length > 1, we interpret that as multiple advanced indices, @@ -1083,13 +1075,15 @@ def index_tensor(data, indices): """ # The typical pattern is to define the new output via te.compute, # with a lambda that describes the element-wise operation. - return te.compute( - data.shape, - lambda *indices: data(*indices) + tvm.tir.const(1, data.dtype), - name="dummy_add_one", - # For a simple element-wise operator, you can use tag=topi.tag.ELEMWISE - tag="elemwise", - ) + # return te.compute( + # data.shape, + # lambda *indices: data(*indices) + tvm.tir.const(1, data.dtype), + # name="dummy_add_one", + # # For a simple element-wise operator, you can use tag=topi.tag.ELEMWISE + # tag="elemwise", + # ) # TODO this should work + + return topi.sum(data, axis=[0]) # TODO this also works # return data diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index df246c9eab2e..dd6a8bd98449 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -476,87 +476,43 @@ TVM_REGISTER_OP("relax.flatten") .set_attr("FPurity", Bool(true)); /* relax.index_tensor */ -TVM_REGISTER_NODE_TYPE(IndexTensorAttrs); - -Expr index_tensor(Expr x, Array indices) { - auto attrs = make_object(); - attrs->indices = std::move(indices); - +Expr index_tensor(Expr data, Expr indices) { static const Op& op = Op::Get("relax.index_tensor"); - return Call(op, {std::move(x)}, Attrs(attrs), {}); + return Call(op, {std::move(data), std::move(indices)}, Attrs(), {}); } TVM_REGISTER_GLOBAL("relax.op.index_tensor").set_body_typed(index_tensor); -// TODO understand every line here? Is this all correct? StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) { - CheckNumArguments(call, ctx); - TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); - - // TODO the commented out checks below fail, understand why! - // // StructInfo inference when the index is a PrimValue is equivalent - // // to that of a scalar (0-d) tensor. - // TensorStructInfo indices_sinfo = [&]() { - // auto arg = call->args[0]; // TODO changed this from 1 to 0, is that ok? - // auto sinfo = GetStructInfo(arg); - // // TODO update the condition below. The indices argument should always be a tensor, it cannot be - // // a scalar value - // if (auto tensor_sinfo = sinfo.as()) { - // return tensor_sinfo.value(); - // } else if (auto prim_sinfo = sinfo.as()) { - // return TensorStructInfo(ShapeExpr(Array{}), prim_sinfo->dtype); - // } else { - // ctx->ReportFatal(Diagnostic::Error(call) - // << "Operator " << call->op << " requires the indices argument to be " - // << "either a tensor or a scalar value. 
" - // << "However, argument " << arg << " has struct info " << sinfo); - // // Unreachable, but [[noreturn]] attribute on virtual function - // // `ReportFatal` is insufficient to silence -Wreturn-type, as - // // child class might not be [[noreturn]]. - // return TensorStructInfo(); - // } - // }(); - - // if (indices_sinfo->IsUnknownDtype()) { - // // TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for warning? - // LOG(WARNING) << "Data type of indice has not been specified. Assume it has an integer type."; - // } else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) { - // ctx->ReportFatal( - // Diagnostic::Error(call) - // << "Index Tensor op requires the input indices to have integer dtype. However, the " - // "given indices dtype is " - // << indices_sinfo->dtype); - // } - - // if (data_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim()) { - // return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); - // } - - // const auto* attrs = call->attrs.as(); - - // const auto* data_shape = data_sinfo->shape.as(); - // const auto* indices_shape = indices_sinfo->shape.as(); - // if (data_shape == nullptr || indices_shape == nullptr) { - // return TensorStructInfo(data_sinfo->dtype, indices_sinfo->ndim + data_sinfo->ndim - 1, - // data_sinfo->vdevice); - // } - - // TODO can we do better than kUnknownNDim, and instead do something like this for the output - // shape? Array output_shape; for (int i = 0; i < data_sinfo->ndim; i++) { - // if (i == axis) { - // for (int j = 0; j < indices_sinfo->ndim; j++) - // output_shape.push_back(indices_shape->values[j]); - // } else { - // output_shape.push_back(data_shape->values[i]); - // } - // } - return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + // TODO most of this is arbitrarily copied from collapse_sum_like. 
Need to understand what we + // actually need + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo data_sinfo = input_sinfo[0]; + TensorStructInfo indices_sinfo = input_sinfo[1]; + + DataType output_dtype = data_sinfo->dtype; + + Optional> data_shape_value; + if (data_sinfo->shape.defined()) { + data_shape_value = GetStructInfoAs(data_sinfo->shape.value())->values; + } + Optional> indices_shape_value; + if (indices_sinfo->shape.defined()) { + indices_shape_value = + GetStructInfoAs(indices_sinfo->shape.value())->values; + } + + if (indices_sinfo->shape.defined()) { + return TensorStructInfo(indices_sinfo->shape.value(), output_dtype, indices_sinfo->vdevice); + } else { + return TensorStructInfo(output_dtype, indices_sinfo->ndim, indices_sinfo->vdevice); + } } TVM_REGISTER_OP("relax.index_tensor") - .set_attrs_type() - .set_num_inputs(1) - .add_argument("x", "Tensor", "The input tensor.") + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("indices", "Tensor", "The indices tensor.") .set_attr("FInferStructInfo", InferStructInfoIndexTensor) .set_attr("FPurity", Bool(true)); @@ -1325,52 +1281,6 @@ TVM_REGISTER_OP("relax.collapse_sum_like") .set_attr("FInferStructInfo", InferStructInfoCollapseSumLike) .set_attr("FPurity", Bool(true)); -/* relax.collapse_sum_like_TWO */ -Expr collapse_sum_like_TWO(Expr data, Expr collapse_target) { - static const Op& op = Op::Get("relax.collapse_sum_like_TWO"); - return Call(op, {std::move(data), std::move(collapse_target)}, Attrs(), {}); -} - -TVM_REGISTER_GLOBAL("relax.op.collapse_sum_like_TWO").set_body_typed(collapse_sum_like_TWO); - -StructInfo InferStructInfoCollapseSumLikeTWO(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); - TensorStructInfo data_sinfo = input_sinfo[0]; - TensorStructInfo collapse_target_sinfo = input_sinfo[1]; - - DataType output_dtype = data_sinfo->dtype; - - Optional> data_shape_value; - if (data_sinfo->shape.defined()) { - data_shape_value = GetStructInfoAs(data_sinfo->shape.value())->values; - } - Optional> collapse_target_shape_value; - if (collapse_target_sinfo->shape.defined()) { - collapse_target_shape_value = - GetStructInfoAs(collapse_target_sinfo->shape.value())->values; - } - - if (data_shape_value.defined() && collapse_target_shape_value.defined()) { - CheckCollapseShape(call, ctx, data_shape_value.value(), collapse_target_shape_value.value()); - } - - if (collapse_target_sinfo->shape.defined()) { - return TensorStructInfo(collapse_target_sinfo->shape.value(), output_dtype, - collapse_target_sinfo->vdevice); - } else { - return TensorStructInfo(output_dtype, collapse_target_sinfo->ndim, - collapse_target_sinfo->vdevice); - } -} - -TVM_REGISTER_OP("relax.collapse_sum_like_TWO") - .set_num_inputs(2) - .add_argument("data", "Tensor", "The input tensor.") - .add_argument("collapse_target", "Tensor", - "The tensor whose shape is the shape to collapse to.") - .set_attr("FInferStructInfo", InferStructInfoCollapseSumLikeTWO) - .set_attr("FPurity", Bool(true)); - /* relax.collapse_sum_to */ Expr collapse_sum_to(Expr data, Expr shape) { static const Op& op = Op::Get("relax.collapse_sum_to"); diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index f22719faab4c..7774dd45c3b8 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -220,7 +220,7 @@ Expr gather_nd(Expr data, Expr indices, int batch_dims = 0); * The output shape is batch_dims + 
indices.shape[:-1] + data.shape[batch_dims + * indices.shape[-1]:] */ -Expr index_tensor(Expr data, Array indices); +Expr index_tensor(Expr data, Expr indices); /*! * \brief Scatter updates into an array according to indices. From 4df949acc80f66e08dd78f03cbdf23b9ba213bc8 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 17:44:24 -0400 Subject: [PATCH 044/105] able to get an ouptut from index.tenosr ! --- include/tvm/relax/attrs/manipulate.h | 13 +++++++------ python/tvm/relax/op/op_attrs.py | 7 +++---- src/relax/op/tensor/manipulate.h | 9 --------- 3 files changed, 10 insertions(+), 19 deletions(-) diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h index 91fbee8c591d..258aa20d4703 100644 --- a/include/tvm/relax/attrs/manipulate.h +++ b/include/tvm/relax/attrs/manipulate.h @@ -169,13 +169,14 @@ struct GatherNDAttrs : public tvm::AttrsNode { } }; // struct GatherNDAttrs +// TODO maybe we don't need this? /*! \brief Attributes used in index_tensor operators */ -struct IndexTensorAttrs : public tvm::AttrsNode { - Array indices; // TODO will need to extend this, since could be an array of arrays? - TVM_DECLARE_ATTRS(IndexTensorAttrs, "relax.attrs.IndexTensorAttrs") { - TVM_ATTR_FIELD(indices).describe("The indices to select."); - } -}; // struct IndexTensorAttrs +// struct IndexTensorAttrs : public tvm::AttrsNode { +// Array indices; // TODO will need to extend this, since could be an array of arrays? +// TVM_DECLARE_ATTRS(IndexTensorAttrs, "relax.attrs.IndexTensorAttrs") { +// TVM_ATTR_FIELD(indices).describe("The indices to select."); +// } +// }; // struct IndexTensorAttrs /*! \brief Attributes used in scatter_elements operators */ struct ScatterElementsAttrs : public tvm::AttrsNode { diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index b80ef3512d4d..58fe6b2f5f6d 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -185,7 +185,6 @@ class FlipAttrs(Attrs): # TODO is this needed? It looks like not all ops are here -@tvm._ffi.register_object("relax.attrs.IndexTensorAttrs") -class IndexTensorAttrs(Attrs): - """Attributes used in index_tensor operator""" - +# @tvm._ffi.register_object("relax.attrs.IndexTensorAttrs") +# class IndexTensorAttrs(Attrs): +# """Attributes used in index_tensor operator""" diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 7774dd45c3b8..0c7d482ffacb 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -127,15 +127,6 @@ Expr squeeze(Expr x, Optional> axis); */ Expr collapse_sum_like(Expr data, Expr collapse_target); -/*! - * \brief Return a summation of data to the shape of collapse_target. - * For details, please see the operator `relax.collapse_sum_to`. - * \param data The input tensor. - * \param collapse_target The tensor whose shape is the shape to collapse to. - * \return The result tensor after summation. - */ -Expr collapse_sum_like_TWO(Expr data, Expr collapse_target); - /*! * \brief Return a summation of data to the given shape. 
* collapse_sum_to is intended as the backward operator of broadcast_to and From d9f25899927f1b5d67c82f9ca27ccc42f1ab6d62 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 19:00:38 -0400 Subject: [PATCH 045/105] need to isolate what goes wrong in building --- python/tvm/topi/transform.py | 182 ++++++++++++++++++----------------- 1 file changed, 94 insertions(+), 88 deletions(-) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 02f03cc2cfab..bfb94e186251 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1081,98 +1081,104 @@ def index_tensor(data, indices): # name="dummy_add_one", # # For a simple element-wise operator, you can use tag=topi.tag.ELEMWISE # tag="elemwise", - # ) # TODO this should work + # ) # TODO this also works - return topi.sum(data, axis=[0]) # TODO this also works + # return topi.sum(data, axis=[0]) # TODO this also works # return data # TODO uncomment - # # Helper to fix negative indices: out_idx = where(idx<0, idx+dim_size, idx) - # def _fix_negatives(idx_t, dim_size): - # # idx_t, dim_size are tvm.te.Tensor or integers. - # # We'll broadcast if needed. We can do so by calling topi.where(...) with the condition - # # (idx_t < 0). - # # For static shape, `dim_size` could be int. For dynamic shape, dim_size might be a Tensor. - # # Suppose dim_size is int here. Then we can just do: - - # # TODO uncomment - # # zero_t = topi.full_like(idx_t, 0) - # # dim_size_t = topi.full_like(idx_t, dim_size) # broadcast if needed - # # return topi.where(topi.less(idx_t, zero_t), topi.add(idx_t, dim_size_t), idx_t) - # return topi.where(idx_t < 0, idx_t + dim_size, idx_t) - - # # --- Check whether indices is multiple advanced indices or single advanced index. --- - # if isinstance(indices, (list, tuple)) and len(indices) > 1: - # # ----------------------------------------------------------- - # # CASE B: multiple advanced indices - # # ----------------------------------------------------------- - # # Suppose each sub_i is a tvm.te.Tensor of integer type, indexing a separate dimension. - # # We want to broadcast them to a common shape (if not already), - # # fix negative indices, then use topi.adv_index. - # idx_list = list(indices) - - # # 1) Determine broadcast shape. For simplicity we can rely on `topi.adv_index` automatically - # # broadcasting the indices if they differ in shape. If you need explicit broadcasting, - # # you can do so via topi utilities (e.g. topi.broadcast_to). - # # Then fix negative indexing dimensionwise. - # # data.shape is e.g. [d0, d1, d2, ...], so for the i-th advanced index, dimension = data.shape[i]. - # # We fix negative indexing if desired: - # final_indices = [] - # for i, idx_t in enumerate(idx_list): - # # If we want negative fix, do it here: - # dim_size = data.shape[i] # a PrimExpr - # fixed = _fix_negatives(idx_t, dim_size) - # final_indices.append(fixed) - - # # 2) Use topi.adv_index - # # This will produce a new tensor with shape = broadcast of final_indices. - # result = topi.adv_index(data, final_indices) - # return result - - # else: - # # ----------------------------------------------------------- - # # CASE A: single advanced index - # # ----------------------------------------------------------- - # # We interpret 'indices' as a single integer-tensor for dimension=0 indexing. - # # So the result shape is [*indices_shape, leftover_dims], with leftover_dims = data.shape[1:]. 
- # # - # # Steps, paralleling the Python: - # # 1) If the first dimension of indices is 1, remove it => topi.squeeze if we want. - # # 2) Flatten => topi.reshape - # # 3) fix negative indices => topi.where - # # 4) gather => topi.take(..., axis=0) - # # 5) reshape => combine advanced dims + leftover dims - # idx_t = indices if isinstance(indices, te.Tensor) else indices[0] - - # # Possibly remove leading dimension if shape[0]==1: - # if len(idx_t.shape) > 0: - # first_dim = idx_t.shape[0] - # if isinstance(first_dim, int) and first_dim == 1: - # # topi.squeeze can remove exactly one axis: - # idx_t = topi.squeeze(idx_t, axis=[0]) - # else: - # # If we suspect it's dynamic, we can check with a small schedule or approach, - # # but here's the naive approach: we skip if the dimension is unknown - # pass - - # # Flatten - # flattened = topi.reshape(idx_t, (-1,)) - - # # fix negative indexing - # # data.shape[0] is batch dimension - # fixed = _fix_negatives(flattened, data.shape[0]) - - # # gather => topi.take - # # out shape = [len_of_fixed] + leftover_dims - # picked = topi.take(data, fixed, axis=0) - - # # final reshape => idx_t original shape (after squeeze) + leftover - # # we can get idx_t's shape with topi.shape if dynamic, or known statically - # adv_shape = tuple(idx_t.shape) # or topi.shape(idx_t) if dynamic - # leftover_dims = data.shape[1:] - # final_shape = adv_shape + leftover_dims - # result = topi.reshape(picked, final_shape) - # return result + # Helper to fix negative indices: out_idx = where(idx<0, idx+dim_size, idx) + def _fix_negatives(idx_t, dim_size): + # idx_t, dim_size are tvm.te.Tensor or integers. + # We'll broadcast if needed. We can do so by calling topi.where(...) with the condition + # (idx_t < 0). + # For static shape, `dim_size` could be int. For dynamic shape, dim_size might be a Tensor. + # Suppose dim_size is int here. Then we can just do: + + # TODO uncomment + zero_t = topi.full_like(idx_t, 0) + dim_size_t = topi.full_like(idx_t, dim_size) # broadcast if needed + return topi.where(topi.less(idx_t, zero_t), topi.add(idx_t, dim_size_t), idx_t) + + # --- Check whether indices is multiple advanced indices or single advanced index. --- + if isinstance(indices, (list, tuple)) and len(indices) > 1: + # ----------------------------------------------------------- + # CASE B: multiple advanced indices + # ----------------------------------------------------------- + # Suppose each sub_i is a tvm.te.Tensor of integer type, indexing a separate dimension. + # We want to broadcast them to a common shape (if not already), + # fix negative indices, then use topi.adv_index. + idx_list = list(indices) + + # 1) Determine broadcast shape. For simplicity we can rely on `topi.adv_index` automatically + # broadcasting the indices if they differ in shape. If you need explicit broadcasting, + # you can do so via topi utilities (e.g. topi.broadcast_to). + # Then fix negative indexing dimensionwise. + # data.shape is e.g. [d0, d1, d2, ...], so for the i-th advanced index, dimension = data.shape[i]. + # We fix negative indexing if desired: + final_indices = [] + for i, idx_t in enumerate(idx_list): + # If we want negative fix, do it here: + dim_size = data.shape[i] # a PrimExpr + fixed = _fix_negatives(idx_t, dim_size) + final_indices.append(fixed) + + # 2) Use topi.adv_index + # This will produce a new tensor with shape = broadcast of final_indices. 
+ result = topi.adv_index(data, final_indices) + return result + + else: + # ----------------------------------------------------------- + # CASE A: single advanced index + # ----------------------------------------------------------- + # We interpret 'indices' as a single integer-tensor for dimension=0 indexing. + # So the result shape is [*indices_shape, leftover_dims], with leftover_dims = data.shape[1:]. + # + # Steps, paralleling the Python: + # 1) If the first dimension of indices is 1, remove it => topi.squeeze if we want. + # 2) Flatten => topi.reshape + # 3) fix negative indices => topi.where + # 4) gather => topi.take(..., axis=0) + # 5) reshape => combine advanced dims + leftover dims + idx_t = indices if isinstance(indices, te.Tensor) else indices[0] + return topi.sum(data, axis=[0]) # TODO this also works + + + # Possibly remove leading dimension if shape[0]==1: + if len(idx_t.shape) > 0: + first_dim = idx_t.shape[0] + if isinstance(first_dim, int) and first_dim == 1: + # topi.squeeze can remove exactly one axis: + idx_t = topi.squeeze(idx_t, axis=[0]) + else: + # If we suspect it's dynamic, we can check with a small schedule or approach, + # but here's the naive approach: we skip if the dimension is unknown + pass + + # Flatten + flattened = topi.reshape(idx_t, (-1,)) + + # fix negative indexing + # data.shape[0] is batch dimension + fixed = _fix_negatives(flattened, data.shape[0]) + + # gather => topi.take + # out shape = [len_of_fixed] + leftover_dims + picked = topi.take(data, fixed, axis=0) + + # final reshape => idx_t original shape (after squeeze) + leftover + # we can get idx_t's shape with topi.shape if dynamic, or known statically + adv_shape = tuple(idx_t.shape) # or topi.shape(idx_t) if dynamic + leftover_dims = tuple(data.shape[1:]) + print('adv_shape type', type(adv_shape)) + print('leftover_dims type', type(leftover_dims)) + print("A #############################################") + final_shape = adv_shape + leftover_dims + print("B #############################################") + result = topi.reshape(picked, final_shape) + print("C #############################################") + return result From e6ae241c405158695c1d9a21d17de2ca1a3e16ce Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 19:01:03 -0400 Subject: [PATCH 046/105] ok --- python/tvm/topi/transform.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index bfb94e186251..267420a76218 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1145,7 +1145,6 @@ def _fix_negatives(idx_t, dim_size): # 4) gather => topi.take(..., axis=0) # 5) reshape => combine advanced dims + leftover dims idx_t = indices if isinstance(indices, te.Tensor) else indices[0] - return topi.sum(data, axis=[0]) # TODO this also works # Possibly remove leading dimension if shape[0]==1: @@ -1159,6 +1158,8 @@ def _fix_negatives(idx_t, dim_size): # but here's the naive approach: we skip if the dimension is unknown pass + return topi.sum(data, axis=[0]) # TODO this also works + # Flatten flattened = topi.reshape(idx_t, (-1,)) From 8db2a09a58bdea1800ef8cf1eb19c6ecda14f7fd Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 19:01:25 -0400 Subject: [PATCH 047/105] ok --- python/tvm/topi/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 267420a76218..584f663bb157 100644 --- a/python/tvm/topi/transform.py 
+++ b/python/tvm/topi/transform.py @@ -1158,10 +1158,10 @@ def _fix_negatives(idx_t, dim_size): # but here's the naive approach: we skip if the dimension is unknown pass - return topi.sum(data, axis=[0]) # TODO this also works # Flatten flattened = topi.reshape(idx_t, (-1,)) + return topi.sum(data, axis=[0]) # TODO this also works # fix negative indexing # data.shape[0] is batch dimension From 301bc403f7fd0a718fe5072c93716907d2f25c83 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 19:01:40 -0400 Subject: [PATCH 048/105] ok --- python/tvm/topi/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 584f663bb157..ee105c128fcb 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1161,11 +1161,11 @@ def _fix_negatives(idx_t, dim_size): # Flatten flattened = topi.reshape(idx_t, (-1,)) - return topi.sum(data, axis=[0]) # TODO this also works # fix negative indexing # data.shape[0] is batch dimension fixed = _fix_negatives(flattened, data.shape[0]) + return topi.sum(data, axis=[0]) # TODO this also works # gather => topi.take # out shape = [len_of_fixed] + leftover_dims From 57d01ffde009c1d47d674d1ad8293749b71a38ec Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 19:02:00 -0400 Subject: [PATCH 049/105] ok --- python/tvm/topi/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index ee105c128fcb..3ddba9c4d1f7 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1165,11 +1165,11 @@ def _fix_negatives(idx_t, dim_size): # fix negative indexing # data.shape[0] is batch dimension fixed = _fix_negatives(flattened, data.shape[0]) - return topi.sum(data, axis=[0]) # TODO this also works # gather => topi.take # out shape = [len_of_fixed] + leftover_dims picked = topi.take(data, fixed, axis=0) + return topi.sum(data, axis=[0]) # TODO this also works # final reshape => idx_t original shape (after squeeze) + leftover # we can get idx_t's shape with topi.shape if dynamic, or known statically From d04a814da932aed30bc74c277115504d4101d3cc Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 19:04:50 -0400 Subject: [PATCH 050/105] ok --- python/tvm/topi/transform.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 3ddba9c4d1f7..c3c0abc7f356 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1169,11 +1169,14 @@ def _fix_negatives(idx_t, dim_size): # gather => topi.take # out shape = [len_of_fixed] + leftover_dims picked = topi.take(data, fixed, axis=0) - return topi.sum(data, axis=[0]) # TODO this also works # final reshape => idx_t original shape (after squeeze) + leftover # we can get idx_t's shape with topi.shape if dynamic, or known statically adv_shape = tuple(idx_t.shape) # or topi.shape(idx_t) if dynamic + + return topi.sum(data, axis=[0]) # TODO this also works + + leftover_dims = tuple(data.shape[1:]) print('adv_shape type', type(adv_shape)) print('leftover_dims type', type(leftover_dims)) From 098c5418c88ecf599651944cbc4464ce324aac90 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 19:11:35 -0400 Subject: [PATCH 051/105] calculate 3 results. 
result2 and result3 have same type but only result2 works --- python/tvm/topi/transform.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index c3c0abc7f356..67dbacd5c475 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1174,7 +1174,6 @@ def _fix_negatives(idx_t, dim_size): # we can get idx_t's shape with topi.shape if dynamic, or known statically adv_shape = tuple(idx_t.shape) # or topi.shape(idx_t) if dynamic - return topi.sum(data, axis=[0]) # TODO this also works leftover_dims = tuple(data.shape[1:]) @@ -1185,4 +1184,15 @@ def _fix_negatives(idx_t, dim_size): print("B #############################################") result = topi.reshape(picked, final_shape) print("C #############################################") - return result + # return result + result2 = topi.sum(data, axis=[0]) # TODO this also works + result3 = topi.sum(result, axis=[0]) # TODO this also works + + print("type(result)", type(result)) + print("type(result2)", type(result2)) + print("type(result3)", type(result3)) + print("result", result) # result Tensor(shape=[T.int64(2), T.int64(4)], op.name=T_reshape) + print("result2", result2) # result2 Tensor(shape=[T.int64(4)], op.name=x_red) + print("result3", result3) # result3 Tensor(shape=[T.int64(4)], op.name=T_reshape_red) + + return result2 From 82775dc045e84e56719cd5b06016b03c7eefb040 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 19:32:51 -0400 Subject: [PATCH 052/105] gets correctness with topi.take --- python/tvm/topi/transform.py | 125 +---------------------------------- 1 file changed, 3 insertions(+), 122 deletions(-) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 67dbacd5c475..2f78edcd4145 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1073,126 +1073,7 @@ def index_tensor(data, indices): 2. Remove exactly one leading dimension of size=1, if present. (Matches PyTorch's shape rule.) 3. Flatten -> fix negative indices -> index_select -> reshape. """ - # The typical pattern is to define the new output via te.compute, - # with a lambda that describes the element-wise operation. - # return te.compute( - # data.shape, - # lambda *indices: data(*indices) + tvm.tir.const(1, data.dtype), - # name="dummy_add_one", - # # For a simple element-wise operator, you can use tag=topi.tag.ELEMWISE - # tag="elemwise", - # ) # TODO this also works - - # return topi.sum(data, axis=[0]) # TODO this also works - - # return data - - # TODO uncomment - - - # Helper to fix negative indices: out_idx = where(idx<0, idx+dim_size, idx) - def _fix_negatives(idx_t, dim_size): - # idx_t, dim_size are tvm.te.Tensor or integers. - # We'll broadcast if needed. We can do so by calling topi.where(...) with the condition - # (idx_t < 0). - # For static shape, `dim_size` could be int. For dynamic shape, dim_size might be a Tensor. - # Suppose dim_size is int here. Then we can just do: - - # TODO uncomment - zero_t = topi.full_like(idx_t, 0) - dim_size_t = topi.full_like(idx_t, dim_size) # broadcast if needed - return topi.where(topi.less(idx_t, zero_t), topi.add(idx_t, dim_size_t), idx_t) - - # --- Check whether indices is multiple advanced indices or single advanced index. 
--- - if isinstance(indices, (list, tuple)) and len(indices) > 1: - # ----------------------------------------------------------- - # CASE B: multiple advanced indices - # ----------------------------------------------------------- - # Suppose each sub_i is a tvm.te.Tensor of integer type, indexing a separate dimension. - # We want to broadcast them to a common shape (if not already), - # fix negative indices, then use topi.adv_index. - idx_list = list(indices) - - # 1) Determine broadcast shape. For simplicity we can rely on `topi.adv_index` automatically - # broadcasting the indices if they differ in shape. If you need explicit broadcasting, - # you can do so via topi utilities (e.g. topi.broadcast_to). - # Then fix negative indexing dimensionwise. - # data.shape is e.g. [d0, d1, d2, ...], so for the i-th advanced index, dimension = data.shape[i]. - # We fix negative indexing if desired: - final_indices = [] - for i, idx_t in enumerate(idx_list): - # If we want negative fix, do it here: - dim_size = data.shape[i] # a PrimExpr - fixed = _fix_negatives(idx_t, dim_size) - final_indices.append(fixed) - - # 2) Use topi.adv_index - # This will produce a new tensor with shape = broadcast of final_indices. - result = topi.adv_index(data, final_indices) - return result + # flattened = topi.reshape(indices, (-1,)) + picked = topi.take(data, indices, axis=0) + return picked - else: - # ----------------------------------------------------------- - # CASE A: single advanced index - # ----------------------------------------------------------- - # We interpret 'indices' as a single integer-tensor for dimension=0 indexing. - # So the result shape is [*indices_shape, leftover_dims], with leftover_dims = data.shape[1:]. - # - # Steps, paralleling the Python: - # 1) If the first dimension of indices is 1, remove it => topi.squeeze if we want. 
- # 2) Flatten => topi.reshape - # 3) fix negative indices => topi.where - # 4) gather => topi.take(..., axis=0) - # 5) reshape => combine advanced dims + leftover dims - idx_t = indices if isinstance(indices, te.Tensor) else indices[0] - - - # Possibly remove leading dimension if shape[0]==1: - if len(idx_t.shape) > 0: - first_dim = idx_t.shape[0] - if isinstance(first_dim, int) and first_dim == 1: - # topi.squeeze can remove exactly one axis: - idx_t = topi.squeeze(idx_t, axis=[0]) - else: - # If we suspect it's dynamic, we can check with a small schedule or approach, - # but here's the naive approach: we skip if the dimension is unknown - pass - - - # Flatten - flattened = topi.reshape(idx_t, (-1,)) - - # fix negative indexing - # data.shape[0] is batch dimension - fixed = _fix_negatives(flattened, data.shape[0]) - - # gather => topi.take - # out shape = [len_of_fixed] + leftover_dims - picked = topi.take(data, fixed, axis=0) - - # final reshape => idx_t original shape (after squeeze) + leftover - # we can get idx_t's shape with topi.shape if dynamic, or known statically - adv_shape = tuple(idx_t.shape) # or topi.shape(idx_t) if dynamic - - - - leftover_dims = tuple(data.shape[1:]) - print('adv_shape type', type(adv_shape)) - print('leftover_dims type', type(leftover_dims)) - print("A #############################################") - final_shape = adv_shape + leftover_dims - print("B #############################################") - result = topi.reshape(picked, final_shape) - print("C #############################################") - # return result - result2 = topi.sum(data, axis=[0]) # TODO this also works - result3 = topi.sum(result, axis=[0]) # TODO this also works - - print("type(result)", type(result)) - print("type(result2)", type(result2)) - print("type(result3)", type(result3)) - print("result", result) # result Tensor(shape=[T.int64(2), T.int64(4)], op.name=T_reshape) - print("result2", result2) # result2 Tensor(shape=[T.int64(4)], op.name=x_red) - print("result3", result3) # result3 Tensor(shape=[T.int64(4)], op.name=T_reshape_red) - - return result2 From 82e9edd59ea4b54626ae87a3eea38f72c4d892bc Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 19:48:21 -0400 Subject: [PATCH 053/105] passing index1D and index2D testsgit status! 
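These index1D/index2D cases exercise PyTorch advanced indexing with a single index tensor: x[idx] gathers rows along dim 0, so the result shape is idx.shape + x.shape[1:]. A minimal reference sketch of the behavior the new tests check (the shapes below are illustrative, not taken from the test file); the topi.take(axis=0) lowering above is expected to reproduce exactly this:

    import torch

    x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
    idx = torch.tensor([[0], [2]])   # single advanced index, shape (2, 1)

    expected = x[idx]                # advanced indexing, shape (2, 1, 4)

    # Equivalent explicit gather along dim 0: flatten the index, select rows,
    # then put idx.shape back in front of the leftover dims.
    gathered = torch.index_select(x, 0, idx.reshape(-1))
    emulated = gathered.reshape(*idx.shape, *x.shape[1:])

    assert torch.equal(expected, emulated)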
--- .../relax/test_from_exported_to_cuda.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 8405f48576d8..ef62e67844d5 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -63,6 +63,43 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) +# Test index.Tensor # TODO aggregate into one big tet + +@tvm.testing.parametrize_targets("cuda") +def test_index_tensor0(target, dev): + class IndexModel(nn.Module): + def __init__(self): + super().__init__() + self.position_ids = torch.tensor([0]) + + def forward(self, x): + return x[self.position_ids] + + torch_module = IndexModel().eval() + + raw_data = np.random.rand(3,3).astype("float32") + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_index_tensor1(target, dev): + class IndexModel(nn.Module): + def __init__(self): + super().__init__() + self.position_ids = torch.tensor([[0]]) + + def forward(self, x): + return x[self.position_ids] + + torch_module = IndexModel().eval() + + raw_data = np.random.rand(2,3).astype("float32") + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + + @tvm.testing.parametrize_targets("cuda") def test_tensor_clamp(target, dev): class ClampBothTensor(torch.nn.Module): From 3df31fce8c8bcbdda8d5de99ac60a75c1eeaf31f Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 19:50:27 -0400 Subject: [PATCH 054/105] first 3 tests pass --- .../relax/test_from_exported_to_cuda.py | 23 +++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index ef62e67844d5..40963aa0dc5b 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -67,7 +67,7 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar @tvm.testing.parametrize_targets("cuda") def test_index_tensor0(target, dev): - class IndexModel(nn.Module): + class IndexModel0(nn.Module): def __init__(self): super().__init__() self.position_ids = torch.tensor([0]) @@ -75,7 +75,7 @@ def __init__(self): def forward(self, x): return x[self.position_ids] - torch_module = IndexModel().eval() + torch_module = IndexModel0().eval() raw_data = np.random.rand(3,3).astype("float32") @@ -84,7 +84,7 @@ def forward(self, x): @tvm.testing.parametrize_targets("cuda") def test_index_tensor1(target, dev): - class IndexModel(nn.Module): + class IndexModel1(nn.Module): def __init__(self): super().__init__() self.position_ids = torch.tensor([[0]]) @@ -92,13 +92,28 @@ def __init__(self): def forward(self, x): return x[self.position_ids] - torch_module = IndexModel().eval() + torch_module = IndexModel1().eval() raw_data = np.random.rand(2,3).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_index_tensor2(target, dev): + class IndexTensorModel2(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[torch.tensor([0,2])] + + torch_module = IndexTensorModel2().eval() + + raw_data = 
np.random.rand(3,4).astype("float32") + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + @tvm.testing.parametrize_targets("cuda") def test_tensor_clamp(target, dev): From 142a16dde4b2a137301c542c4b951db6cdcc0d1b Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 20:04:13 -0400 Subject: [PATCH 055/105] other test --- .../relax/test_from_exported_to_cuda.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 40963aa0dc5b..163cb62f9473 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -115,6 +115,49 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_index_tensor3(target, dev): + class IndexTensorModel3(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[[[0,1,2,3], [1,2,3,4], [2,3,4,0]]] # both args[0] and indices are expr.Var + + torch_module = IndexTensorModel3().eval() + raw_data = np.random.rand(5,5,5,5).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_index_tensor4(target, dev): + class IndexTensorModel4(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[[[[0,1],[2,3]],[[4,5],[6,7]],[[2,4],[1,2]],[[0,4],[0,3]]]] + + torch_module = IndexTensorModel4().eval() + raw_data = np.random.rand(5,5,5,5).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + + +@tvm.testing.parametrize_targets("cuda") +def test_index_tensor5(target, dev): + class IndexTensorModel5(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[[[0,1],[0,1]]] + + torch_module = IndexTensorModel5().eval() + raw_data = np.random.rand(5,5,5,5).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + @tvm.testing.parametrize_targets("cuda") def test_tensor_clamp(target, dev): class ClampBothTensor(torch.nn.Module): From bfb2ea92b3363cfa2141b658ff1ace963988ad50 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 20:14:40 -0400 Subject: [PATCH 056/105] all tests written. 
0 to 5 pass, 6 to 8 fail --- .../relax/test_from_exported_to_cuda.py | 59 +++++++++++++++---- 1 file changed, 49 insertions(+), 10 deletions(-) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 163cb62f9473..e11afdf7b364 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -70,10 +70,9 @@ def test_index_tensor0(target, dev): class IndexModel0(nn.Module): def __init__(self): super().__init__() - self.position_ids = torch.tensor([0]) def forward(self, x): - return x[self.position_ids] + return x[torch.tensor([0])] torch_module = IndexModel0().eval() @@ -87,10 +86,9 @@ def test_index_tensor1(target, dev): class IndexModel1(nn.Module): def __init__(self): super().__init__() - self.position_ids = torch.tensor([[0]]) def forward(self, x): - return x[self.position_ids] + return x[torch.tensor([[0]])] torch_module = IndexModel1().eval() @@ -122,10 +120,10 @@ def __init__(self): super().__init__() def forward(self, x): - return x[[[0,1,2,3], [1,2,3,4], [2,3,4,0]]] # both args[0] and indices are expr.Var + return x[[[[0,2],[1,3]]]] torch_module = IndexTensorModel3().eval() - raw_data = np.random.rand(5,5,5,5).astype("float32") + raw_data = np.random.rand(5,5,5).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) @@ -136,24 +134,65 @@ def __init__(self): super().__init__() def forward(self, x): - return x[[[[0,1],[2,3]],[[4,5],[6,7]],[[2,4],[1,2]],[[0,4],[0,3]]]] + return x[[[1,4]]] torch_module = IndexTensorModel4().eval() - raw_data = np.random.rand(5,5,5,5).astype("float32") + raw_data = np.random.rand(5,5,5).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - @tvm.testing.parametrize_targets("cuda") def test_index_tensor5(target, dev): class IndexTensorModel5(nn.Module): def __init__(self): super().__init__() + def forward(self, x): + return x[[[[1,2,4]]]] + + torch_module = IndexTensorModel5().eval() + raw_data = np.random.rand(5,5,5).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_index_tensor6(target, dev): + class IndexTensorModel6(nn.Module): + def __init__(self): + super().__init__() + def forward(self, x): return x[[[0,1],[0,1]]] - torch_module = IndexTensorModel5().eval() + torch_module = IndexTensorModel6().eval() + raw_data = np.random.rand(5,5,5,5).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_index_tensor7(target, dev): + class IndexTensorModel7(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[[[0,1,2,3], [1,2,3,4], [2,3,4,0]]] # both args[0] and indices are expr.Var + + torch_module = IndexTensorModel7().eval() + raw_data = np.random.rand(5,5,5).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_index_tensor8(target, dev): + class IndexTensorModel8(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[[[[0,1],[2,3]],[[2,3],[3,4]],[[2,4],[1,2]],[[0,4],[0,3]]]] + + torch_module = IndexTensorModel8().eval() raw_data = np.random.rand(5,5,5,5).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, 
target, dev) From 03d1124b9c539417a893f523c67497dca303c9d0 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 20:23:43 -0400 Subject: [PATCH 057/105] added full --- .../frontend/torch/base_fx_graph_translator.py | 17 +++++++++++++++++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 17 ----------------- .../python/relax/test_from_exported_to_cuda.py | 17 +++++++++++++++++ 4 files changed, 35 insertions(+), 17 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 7cc390d5e67b..898dc857794f 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1306,6 +1306,23 @@ def _fill(self, node: fx.Node) -> relax.Var: value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) + def _full(self, node: fx.Node) -> relax.Var: + import torch + + args = self.retrieve_args(node) + size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],)) + dtype = self._convert_data_type( + node.kwargs.get("dtype", torch.get_default_dtype()), self.env + ) + value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype) + return self.block_builder.emit( + relax.op.full( + size, + value, + dtype, + ) + ) + def _index_select(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] dim = node.args[1] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 9b1676a54645..3937ca260583 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -434,6 +434,7 @@ def create_convert_map( "empty.memory_format": self._empty, "empty_like.default": self._empty_like, "fill.Scalar": self._fill, + "full.default": self._full, "index_select.default": self._index_select, "lift_fresh_copy.default": self._to_copy, "new_ones.default": self._new_ones, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index a5b50a7d1dce..80031cd7a403 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -468,23 +468,6 @@ def _inplace_fill(self, node: fx.Node) -> relax.Var: self.env[node.args[0]] = filled return filled - def _full(self, node: fx.Node) -> relax.Var: - import torch - - args = self.retrieve_args(node) - size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],)) - dtype = self._convert_data_type( - node.kwargs.get("dtype", torch.get_default_dtype()), self.env - ) - value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype) - return self.block_builder.emit( - relax.op.full( - size, - value, - dtype, - ) - ) - def _inplace_masked_fill(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] mask = self.env[node.args[1]] diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index e11afdf7b364..5e9932cf9051 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -63,6 +63,23 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar 
np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) + +@tvm.testing.parametrize_targets("cuda") +def test_full(target, dev): + class FullModel(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.full((2, 3), 3.141592) + + torch_module = FullModel().eval() + + raw_data = np.random.rand(3,3).astype("float32") + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + # Test index.Tensor # TODO aggregate into one big tet @tvm.testing.parametrize_targets("cuda") From e13fa3d0718eb75ede16a60b138e62b48360d586 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 20:24:29 -0400 Subject: [PATCH 058/105] unit test --- tests/python/relax/test_from_exported_to_cuda.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 8405f48576d8..43107f015313 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -63,6 +63,21 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) +@tvm.testing.parametrize_targets("cuda") +def test_full(target, dev): + class FullModel(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.full((2, 3), 3.141592) + + torch_module = FullModel().eval() + + raw_data = np.random.rand(3,3).astype("float32") + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + @tvm.testing.parametrize_targets("cuda") def test_tensor_clamp(target, dev): class ClampBothTensor(torch.nn.Module): From 5b23c30341fff3765c1226261d84f178531485b1 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 20:26:10 -0400 Subject: [PATCH 059/105] full.default --- .../frontend/torch/base_fx_graph_translator.py | 17 +++++++++++++++++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 17 ----------------- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index c9c6afd71a64..55a603e20c60 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1271,6 +1271,23 @@ def _fill(self, node: fx.Node) -> relax.Var: value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) + def _full(self, node: fx.Node) -> relax.Var: + import torch + + args = self.retrieve_args(node) + size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],)) + dtype = self._convert_data_type( + node.kwargs.get("dtype", torch.get_default_dtype()), self.env + ) + value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype) + return self.block_builder.emit( + relax.op.full( + size, + value, + dtype, + ) + ) + def _index_select(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] dim = node.args[1] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 875ec3b83ea8..26e73dd6b84b 100644 --- 
a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -433,6 +433,7 @@ def create_convert_map( "empty.memory_format": self._empty, "empty_like.default": self._empty_like, "fill.Scalar": self._fill, + "full.default": self._full, "index_select.default": self._index_select, "lift_fresh_copy.default": self._to_copy, "new_ones.default": self._new_ones, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index a5b50a7d1dce..80031cd7a403 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -468,23 +468,6 @@ def _inplace_fill(self, node: fx.Node) -> relax.Var: self.env[node.args[0]] = filled return filled - def _full(self, node: fx.Node) -> relax.Var: - import torch - - args = self.retrieve_args(node) - size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],)) - dtype = self._convert_data_type( - node.kwargs.get("dtype", torch.get_default_dtype()), self.env - ) - value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype) - return self.block_builder.emit( - relax.op.full( - size, - value, - dtype, - ) - ) - def _inplace_masked_fill(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] mask = self.env[node.args[1]] From 35aee297ba2ca01dbdf2695267cf1869b399a95b Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 20:27:32 -0400 Subject: [PATCH 060/105] linting --- tests/python/relax/test_from_exported_to_cuda.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 43107f015313..0a120aa8fb70 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -71,13 +71,14 @@ def __init__(self): def forward(self, x): return torch.full((2, 3), 3.141592) - + torch_module = FullModel().eval() - raw_data = np.random.rand(3,3).astype("float32") + raw_data = np.random.rand(3, 3).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + @tvm.testing.parametrize_targets("cuda") def test_tensor_clamp(target, dev): class ClampBothTensor(torch.nn.Module): From 5c0e18b7b419a8194c696d9e4c5f6194af7b251a Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 20:31:20 -0400 Subject: [PATCH 061/105] ones ok --- .../frontend/torch/base_fx_graph_translator.py | 16 ++++++++++++++++ .../torch/exported_program_translator.py | 1 + python/tvm/relax/frontend/torch/fx_translator.py | 16 ---------------- tests/python/relax/test_from_exported_to_cuda.py | 16 ++++++++++++++++ 4 files changed, 33 insertions(+), 16 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 55a603e20c60..2a811fd33e1e 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1308,6 +1308,22 @@ def _new_ones(self, node: fx.Node) -> relax.Var: self_var.struct_info.dtype, ) ) + + def _ones(self, node: fx.Node) -> relax.Var: + import torch + + args = self.retrieve_args(node) + size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],)) + dtype = self._convert_data_type( + node.kwargs.get("dtype", torch.get_default_dtype()), 
self.env + ) + return self.block_builder.emit( + relax.op.full( + size, + relax.const(1, dtype), + dtype, + ) + ) ########## DataType ########## diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 26e73dd6b84b..e962fbdbc696 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -438,6 +438,7 @@ def create_convert_map( "lift_fresh_copy.default": self._to_copy, "new_ones.default": self._new_ones, "one_hot.default": self._one_hot, + "ones.default": self._ones, # datatype "to.dtype": self._to, "to.dtype_layout": self._to, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 80031cd7a403..f1b9a6d6e28c 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -510,22 +510,6 @@ def _masked_scatter(self, node: fx.Node) -> relax.Var: mask = self.block_builder.emit(relax.op.broadcast_to(mask, x.struct_info.shape)) return self.block_builder.emit(relax.op.where(mask, gathered_source, x)) - def _ones(self, node: fx.Node) -> relax.Var: - import torch - - args = self.retrieve_args(node) - size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],)) - dtype = self._convert_data_type( - node.kwargs.get("dtype", torch.get_default_dtype()), self.env - ) - return self.block_builder.emit( - relax.op.full( - size, - relax.const(1, dtype), - dtype, - ) - ) - def _one_hot(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] num_classes = node.args[1] if len(node.args) > 1 else node.kwargs.get("num_classes") diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 0a120aa8fb70..5a0435d44484 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -78,6 +78,22 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_ones(target, dev): + class FullModel(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.ones((2, 3)) + + torch_module = FullModel().eval() + + raw_data = np.random.rand(1,1).astype("float32") + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + @tvm.testing.parametrize_targets("cuda") def test_tensor_clamp(target, dev): From 40316a073d3554495874478f14a438e02005c045 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 20:54:57 -0400 Subject: [PATCH 062/105] tests for ones, full, and full like work --- .../torch/base_fx_graph_translator.py | 7 ++++++- .../torch/exported_program_translator.py | 1 + .../relax/test_from_exported_to_cuda.py | 20 +++++++++++++++++-- 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 2a811fd33e1e..3018b0db771d 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1288,6 +1288,11 @@ def _full(self, node: fx.Node) -> relax.Var: ) ) + def _full_like(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + fill_value = relax.const(node.args[1]) + return 
self.block_builder.emit(relax.op.full_like(x, fill_value)) + def _index_select(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] dim = node.args[1] @@ -1308,7 +1313,7 @@ def _new_ones(self, node: fx.Node) -> relax.Var: self_var.struct_info.dtype, ) ) - + def _ones(self, node: fx.Node) -> relax.Var: import torch diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index e962fbdbc696..bcb8b6468f72 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -434,6 +434,7 @@ def create_convert_map( "empty_like.default": self._empty_like, "fill.Scalar": self._fill, "full.default": self._full, + "full_like.default": self._full_like, "index_select.default": self._index_select, "lift_fresh_copy.default": self._to_copy, "new_ones.default": self._new_ones, diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 5a0435d44484..e92855885e35 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -78,6 +78,23 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + +@tvm.testing.parametrize_targets("cuda") +def test_full_like(target, dev): + class FullLike(nn.Module): + def __init__(self): + super().__init__() + self.fill_value = 7.0 + + def forward(self, x): + return torch.full_like(x, self.fill_value) + + torch_module = FullLike().eval() + raw_data = np.random.rand(2, 3).astype("float32") + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + @tvm.testing.parametrize_targets("cuda") def test_ones(target, dev): class FullModel(nn.Module): @@ -89,12 +106,11 @@ def forward(self, x): torch_module = FullModel().eval() - raw_data = np.random.rand(1,1).astype("float32") + raw_data = np.random.rand(1, 1).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - @tvm.testing.parametrize_targets("cuda") def test_tensor_clamp(target, dev): class ClampBothTensor(torch.nn.Module): From ac33a594496973ed851f2d9a83f063940abe1b79 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 22:31:22 -0400 Subject: [PATCH 063/105] before switchign to list[Expr] --- python/tvm/topi/transform.py | 106 +++++++++++++++++- .../relax/test_from_exported_to_cuda.py | 17 ++- .../relax/test_from_exported_to_cuda_NEW.py | 82 ++++++++++++++ 3 files changed, 201 insertions(+), 4 deletions(-) create mode 100644 tests/python/relax/test_from_exported_to_cuda_NEW.py diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 2f78edcd4145..78e71b863070 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -956,6 +956,17 @@ def adv_index(data, indices): result : tvm.te.Tensor Output tensor """ + + """ + TODO + this seems to be wrong + Does not achieve correctness with this: + + x np.random.rand(5,5,5,5).astype("float32") + return x[[[0,1],[0,1]]] + + """ + return cpp.adv_index(data, indices) @@ -1073,7 +1084,96 @@ def index_tensor(data, indices): 2. Remove exactly one leading dimension of size=1, if present. (Matches PyTorch's shape rule.) 3. Flatten -> fix negative indices -> index_select -> reshape. 
""" - # flattened = topi.reshape(indices, (-1,)) - picked = topi.take(data, indices, axis=0) - return picked + + + + if isinstance(indices, (list, tuple)) and len(indices) > 1: + + def _broadcast_shape(shapes): + """ + shapes: list of tuples + Return the broadcasted shape for these shapes + """ + max_ndim = max(len(s) for s in shapes) + out_rev = [] + # reverse each shape + rev_shapes = [s[::-1] for s in shapes] + for i in range(max_ndim): + dim_size = 1 + for rsh in rev_shapes: + if i < len(rsh): + s_ = rsh[i] + # typical broadcast rule + if s_ != 1 and dim_size != 1 and s_ != dim_size: + raise ValueError("Incompatible shapes for broadcast") + dim_size = max(dim_size, s_) + out_rev.append(dim_size) + out_rev.reverse() + return tuple(out_rev) + + shapes = [tuple(idx.shape) for idx in idx_list] + broadcast_shape = _broadcast_shape(shapes) + + # ------------------------------------------------- + # 2) Expand (broadcast) each index to shape B + # Then fix negative indices if you want negative support + # ------------------------------------------------- + expanded_idx_list = [] + for i, idx in enumerate(idx_list): + # broadcast to shape B + broadcasted = topi.broadcast_to(idx, broadcast_shape) + + # fix negative: out_idx = where(idx < 0, idx + data.shape[i], idx) + # data.shape[i] might be a PrimExpr or int + dim_size_i = data.shape[i] # dimension size for data's i-th dim + # We must make sure it's broadcast-compatible: + dim_size_t = topi.full_like(broadcasted, dim_size_i) + zero_t = topi.full_like(broadcasted, 0) + fixed = topi.where(topi.less(broadcasted, zero_t), + topi.add(broadcasted, dim_size_t), + broadcasted) + expanded_idx_list.append(fixed) + + # leftover dimensions => data.shape[k:] + k = len(idx_list) + leftover_dims = data.shape[k:] + # Final output shape is broadcast_shape + leftover_dims + final_shape = broadcast_shape + leftover_dims + + # ------------------------------------------------- + # 3) Build a te.compute that gathers from 'data' + # ------------------------------------------------- + def _compute(*args): + # 'args' is a multi-index into final_shape + # => the first len(broadcast_shape) are the broadcast coords + # the remaining correspond to leftover_dims + bdim = len(broadcast_shape) + leftover_dim = len(leftover_dims) + assert len(args) == bdim + leftover_dim + + # advanced_indices for dimension i + # i.e. i0 = expanded_idx_list[0][b0,b1,...], i1 = expanded_idx_list[1][b0,b1,...], ... + # leftover coords => the last leftover_dim from 'args' + # data is presumably shape = [D0, D1, ..., D(k-1), leftover...] + # So data coordinate is [ i0, i1, ..., i(k-1), leftover0, leftover1, ...] 
+ data_coords = [] + for i_ in range(k): + data_coords.append(expanded_idx_list[i_][*args[:bdim]]) + # Now append leftover coords + data_coords.extend(args[bdim:]) + + return data(*data_coords) + + # The final te.compute + out = te.compute( + final_shape, + _compute, + name="multi_adv_index_gather", + ) + return out + + else: + # flattened = topi.reshape(indices, (-1,)) + picked = topi.take(data, indices, axis=0) + return picked diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index c7ff34698d76..b76999d781f1 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -196,7 +196,7 @@ def forward(self, x): return x[[[0,1,2,3], [1,2,3,4], [2,3,4,0]]] # both args[0] and indices are expr.Var torch_module = IndexTensorModel7().eval() - raw_data = np.random.rand(5,5,5).astype("float32") + raw_data = np.random.rand(5,5,5,5).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) @@ -738,6 +738,21 @@ def forward(self, x): raw_data = np.random.rand(10, 10, 10).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_mul(target, dev): + class MulModule(nn.Module): + def __init__(self): + super().__init__() + self.y = torch.tensor(np.random.rand(2, 3).astype("float32")) + + def forward(self, x): + return x.mul(self.y) + + torch_module = MulModule().eval() + raw_data = np.random.rand(2, 3).astype("float32") + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_from_exported_to_cuda_NEW.py b/tests/python/relax/test_from_exported_to_cuda_NEW.py new file mode 100644 index 000000000000..dd449c344012 --- /dev/null +++ b/tests/python/relax/test_from_exported_to_cuda_NEW.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import relax +import tvm.testing +import numpy as np +import torch +from torch import nn +from torch.export import export +from tvm.relax.frontend.torch import from_exported_program +from torch.nn import Softmax, Upsample + + +def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev): + """ + This util ensures that a torch module can successfully be exported to TVM + using torch.export and that the resuling IR program gives the same result + as PyTorch when ran on CUDA. 
+ """ + raw_data_for_tvm = raw_data.copy() # In case the data is modified + torch_data = torch.from_numpy(raw_data) + example_args = (torch_data,) + + with torch.no_grad(): + exported_program = export(torch_module, example_args) + mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True) + + tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch) + + relax_pipeline = relax.get_default_pipeline(tvm.target.Target.from_device(tvm.cuda())) + ex = relax.build(tvm_mod, target=target, relax_pipeline=relax_pipeline) + vm = relax.VirtualMachine(ex, dev) + + gpu_data = tvm.nd.array(raw_data_for_tvm, dev) + gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]] + gpu_out = vm["main"](gpu_data, *gpu_params) + + pytorch_out = torch_module(torch_data) + + if isinstance(pytorch_out, tuple): + for i in range(len(pytorch_out)): + actual = gpu_out[i].numpy() + desired = pytorch_out[i].detach().numpy() + np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) + else: + actual = gpu_out[0].numpy() + desired = pytorch_out.detach().numpy() + np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) + + + +@tvm.testing.parametrize_targets("cuda") +def test_index_tensor6(target, dev): + class IndexTensorModel6(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[[[0,1],[0,1]]] + + torch_module = IndexTensorModel6().eval() + raw_data = np.random.rand(5,5,5,5).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +if __name__ == "__main__": + tvm.testing.main() From eddcd392a976d069b3412aa06b1d3bb809517587 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 13 Apr 2025 23:52:18 -0400 Subject: [PATCH 064/105] able to get list of tensors in topi --- .../torch/base_fx_graph_translator.py | 2 +- python/tvm/relax/op/manipulate.py | 5 +++- .../transform/legalize_ops/manipulate.py | 25 ++++++++++++++++- python/tvm/topi/transform.py | 8 +++++- src/relax/op/tensor/manipulate.cc | 28 +++++-------------- 5 files changed, 43 insertions(+), 25 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 4cde5de90a86..ca8286cbf199 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1089,7 +1089,7 @@ def _index_tensor(self, node: fx.Node) -> relax.Var: # indices = args[1] # TODO do something like this! 
# indices = [2,3] - indices = args[1][0] + indices = args[1] print("type of indices", type(indices)) # print("indices:") # args_indices = self.retrieve_args(indices) diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 159e496fa16c..da55ef30b2e9 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -506,7 +506,8 @@ def gather_nd(data: Expr, indices: Expr, batch_dims: int = 0) -> Expr: """ return _ffi_api.gather_nd(data, indices, batch_dims) # type: ignore -def index_tensor(data: Expr, indices: Expr) -> Expr: + +def index_tensor(data: Expr, indices: Union[Expr, List[Expr]]) -> Expr: """ TODO docstring """ @@ -515,6 +516,8 @@ def index_tensor(data: Expr, indices: Expr) -> Expr: # assert all(isinstance(i, int) for i in indices), "indices should be a list of integers, but got {}".format( # [type(i) for i in indices] # ) + if isinstance(indices, (list, tuple)): + indices = RxTuple(indices) return _ffi_api.index_tensor(data, indices) # type: ignore def scatter_elements( diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index d4aac7abc392..bd86b54e1fd9 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -162,9 +162,32 @@ def te_gather_nd(data, indices, batch_dims): return bb.call_te(te_gather_nd, call.args[0], call.args[1], int(call.attrs.batch_dims)) +# @register_legalize("relax.index_tensor") +# def _index_tensor(bb: BlockBuilder, call: Call) -> Expr: +# return bb.call_te(topi.index_tensor, call.args[0], call.args[1]) # TODO should I use primfunc_name_hint? + + + @register_legalize("relax.index_tensor") def _index_tensor(bb: BlockBuilder, call: Call) -> Expr: - return bb.call_te(topi.index_tensor, call.args[0], call.args[1]) # TODO should I use primfunc_name_hint? + t = call.args[1] + n_field = len(t.struct_info.fields) + while isinstance(t, Var): + binding = bb.lookup_binding(t) + if not isinstance(binding, (Tuple, Var)): + break + t = binding + + assert isinstance(t, (Tuple, Var)) + fields = ( + t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)] + ) + return bb.call_te( + topi.index_tensor, call.args[0], fields + ) + + + @register_legalize("relax.scatter_elements") def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr: diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 78e71b863070..b88681df35d0 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1085,7 +1085,13 @@ def index_tensor(data, indices): 3. Flatten -> fix negative indices -> index_select -> reshape. """ - + print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") + print("WE ARE IN TOPI~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") + print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") + print("type of data", type(data)) + print("type of indices", type(indices)) + print("data", data) + print("indices", indices) if isinstance(indices, (list, tuple)) and len(indices) > 1: diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index dd6a8bd98449..dc08f01351cd 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -477,6 +477,10 @@ TVM_REGISTER_OP("relax.flatten") /* relax.index_tensor */ Expr index_tensor(Expr data, Expr indices) { + // TODO do we need code below? 
+ // ObjectPtr attrs = make_object(); + // attrs->indices = std::move(indices); + static const Op& op = Op::Get("relax.index_tensor"); return Call(op, {std::move(data), std::move(indices)}, Attrs(), {}); } @@ -484,35 +488,17 @@ Expr index_tensor(Expr data, Expr indices) { TVM_REGISTER_GLOBAL("relax.op.index_tensor").set_body_typed(index_tensor); StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) { - // TODO most of this is arbitrarily copied from collapse_sum_like. Need to understand what we - // actually need - Array input_sinfo = GetInputTensorStructInfo(call, ctx); - TensorStructInfo data_sinfo = input_sinfo[0]; - TensorStructInfo indices_sinfo = input_sinfo[1]; + TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); DataType output_dtype = data_sinfo->dtype; - Optional> data_shape_value; - if (data_sinfo->shape.defined()) { - data_shape_value = GetStructInfoAs(data_sinfo->shape.value())->values; - } - Optional> indices_shape_value; - if (indices_sinfo->shape.defined()) { - indices_shape_value = - GetStructInfoAs(indices_sinfo->shape.value())->values; - } - - if (indices_sinfo->shape.defined()) { - return TensorStructInfo(indices_sinfo->shape.value(), output_dtype, indices_sinfo->vdevice); - } else { - return TensorStructInfo(output_dtype, indices_sinfo->ndim, indices_sinfo->vdevice); - } + return TensorStructInfo(output_dtype, kUnknownNDim); } TVM_REGISTER_OP("relax.index_tensor") .set_num_inputs(2) .add_argument("data", "Tensor", "The input tensor.") - .add_argument("indices", "Tensor", "The indices tensor.") + .add_argument("indices", "Tuple of Tensor", "The indices tensor.") .set_attr("FInferStructInfo", InferStructInfoIndexTensor) .set_attr("FPurity", Bool(true)); From 6f62e0d7241fe63a6e05effc21ad2de85fb65ca7 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 00:07:20 -0400 Subject: [PATCH 065/105] unable to reproduce results for second case --- python/tvm/topi/transform.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index b88681df35d0..46f2a4d34c27 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1093,7 +1093,14 @@ def index_tensor(data, indices): print("data", data) print("indices", indices) + is_instance = isinstance(indices, (list, tuple)) + print("isinstance(indices, (list, tuple))", is_instance) + if is_instance: + print("len(indices)", len(indices)) + + if isinstance(indices, (list, tuple)) and len(indices) > 1: + print("IF CASE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") def _broadcast_shape(shapes): """ @@ -1180,6 +1187,12 @@ def _compute(*args): else: # flattened = topi.reshape(indices, (-1,)) - picked = topi.take(data, indices, axis=0) + print("ELSE CASE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") + idxs = indices[0] + print("type(data)", type(data)) + print("data", data) + print("type(idxs)",type(idxs)) + print("idxs", idxs) + picked = topi.take(data, idxs, axis=0) return picked From 65c6ba2662a993daedc12f8648a24ba83b9325b3 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 00:38:20 -0400 Subject: [PATCH 066/105] not working --- .../transform/legalize_ops/manipulate.py | 2 +- python/tvm/topi/transform.py | 222 +++++++++--------- 2 files changed, 113 insertions(+), 111 deletions(-) diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index bd86b54e1fd9..4493135e753d 100644 --- 
a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -164,7 +164,7 @@ def te_gather_nd(data, indices, batch_dims): # @register_legalize("relax.index_tensor") # def _index_tensor(bb: BlockBuilder, call: Call) -> Expr: -# return bb.call_te(topi.index_tensor, call.args[0], call.args[1]) # TODO should I use primfunc_name_hint? +# return bb.call_te(topi.index_tensor, call.args[0], call.args[1][0]) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 46f2a4d34c27..40f2c745dc36 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1085,114 +1085,116 @@ def index_tensor(data, indices): 3. Flatten -> fix negative indices -> index_select -> reshape. """ - print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") - print("WE ARE IN TOPI~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") - print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") - print("type of data", type(data)) - print("type of indices", type(indices)) - print("data", data) - print("indices", indices) - - is_instance = isinstance(indices, (list, tuple)) - print("isinstance(indices, (list, tuple))", is_instance) - if is_instance: - print("len(indices)", len(indices)) - - - if isinstance(indices, (list, tuple)) and len(indices) > 1: - print("IF CASE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") - - def _broadcast_shape(shapes): - """ - shapes: list of tuples - Return the broadcasted shape for these shapes - """ - max_ndim = max(len(s) for s in shapes) - out_rev = [] - # reverse each shape - rev_shapes = [s[::-1] for s in shapes] - for i in range(max_ndim): - dim_size = 1 - for rsh in rev_shapes: - if i < len(rsh): - s_ = rsh[i] - # typical broadcast rule - if s_ != 1 and dim_size != 1 and s_ != dim_size: - raise ValueError("Incompatible shapes for broadcast") - dim_size = max(dim_size, s_) - out_rev.append(dim_size) - out_rev.reverse() - return tuple(out_rev) - - shapes = [tuple(idx.shape) for idx in idx_list] - broadcast_shape = _broadcast_shape(shapes) - - # ------------------------------------------------- - # 2) Expand (broadcast) each index to shape B - # Then fix negative indices if you want negative support - # ------------------------------------------------- - expanded_idx_list = [] - for i, idx in enumerate(idx_list): - # broadcast to shape B - broadcasted = topi.broadcast_to(idx, broadcast_shape) - - # fix negative: out_idx = where(idx < 0, idx + data.shape[i], idx) - # data.shape[i] might be a PrimExpr or int - dim_size_i = data.shape[i] # dimension size for data's i-th dim - # We must make sure it's broadcast-compatible: - dim_size_t = topi.full_like(broadcasted, dim_size_i) - zero_t = topi.full_like(broadcasted, 0) - fixed = topi.where(topi.less(broadcasted, zero_t), - topi.add(broadcasted, dim_size_t), - broadcasted) - expanded_idx_list.append(fixed) - - # leftover dimensions => data.shape[k:] - k = len(idx_list) - leftover_dims = data.shape[k:] - # Final output shape is broadcast_shape + leftover_dims - final_shape = broadcast_shape + leftover_dims - - # ------------------------------------------------- - # 3) Build a te.compute that gathers from 'data' - # ------------------------------------------------- - def _compute(*args): - # 'args' is a multi-index into final_shape - # => the first len(broadcast_shape) are the broadcast coords - # the remaining correspond to leftover_dims - bdim = len(broadcast_shape) - leftover_dim = len(leftover_dims) - assert len(args) == bdim + leftover_dim - - # advanced_indices for 
dimension i - # i.e. i0 = expanded_idx_list[0][b0,b1,...], i1 = expanded_idx_list[1][b0,b1,...], ... - # leftover coords => the last leftover_dim from 'args' - # data is presumably shape = [D0, D1, ..., D(k-1), leftover...] - # So data coordinate is [ i0, i1, ..., i(k-1), leftover0, leftover1, ...] - data_coords = [] - for i_ in range(k): - data_coords.append(expanded_idx_list[i_][*args[:bdim]]) - # Now append leftover coords - data_coords.extend(args[bdim:]) - - return data(*data_coords) - - # The final te.compute - out = te.compute( - final_shape, - _compute, - name="multi_adv_index_gather", - ) - return out - - else: - # flattened = topi.reshape(indices, (-1,)) - print("ELSE CASE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") - idxs = indices[0] - print("type(data)", type(data)) - print("data", data) - print("type(idxs)",type(idxs)) - print("idxs", idxs) - picked = topi.take(data, idxs, axis=0) - return picked + return topi.adv_index(data, indices) + + # print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") + # print("WE ARE IN TOPI~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") + # print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") + # print("type of data", type(data)) + # print("type of indices", type(indices)) + # print("data", data) + # print("indices", indices) + + # is_instance = isinstance(indices, (list, tuple)) + # print("isinstance(indices, (list, tuple))", is_instance) + # if is_instance: + # print("len(indices)", len(indices)) + + + # if isinstance(indices, (list, tuple)) and len(indices) > 1: + # print("IF CASE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") + + # def _broadcast_shape(shapes): + # """ + # shapes: list of tuples + # Return the broadcasted shape for these shapes + # """ + # max_ndim = max(len(s) for s in shapes) + # out_rev = [] + # # reverse each shape + # rev_shapes = [s[::-1] for s in shapes] + # for i in range(max_ndim): + # dim_size = 1 + # for rsh in rev_shapes: + # if i < len(rsh): + # s_ = rsh[i] + # # typical broadcast rule + # if s_ != 1 and dim_size != 1 and s_ != dim_size: + # raise ValueError("Incompatible shapes for broadcast") + # dim_size = max(dim_size, s_) + # out_rev.append(dim_size) + # out_rev.reverse() + # return tuple(out_rev) + + # shapes = [tuple(idx.shape) for idx in idx_list] + # broadcast_shape = _broadcast_shape(shapes) + + # # ------------------------------------------------- + # # 2) Expand (broadcast) each index to shape B + # # Then fix negative indices if you want negative support + # # ------------------------------------------------- + # expanded_idx_list = [] + # for i, idx in enumerate(idx_list): + # # broadcast to shape B + # broadcasted = topi.broadcast_to(idx, broadcast_shape) + + # # fix negative: out_idx = where(idx < 0, idx + data.shape[i], idx) + # # data.shape[i] might be a PrimExpr or int + # dim_size_i = data.shape[i] # dimension size for data's i-th dim + # # We must make sure it's broadcast-compatible: + # dim_size_t = topi.full_like(broadcasted, dim_size_i) + # zero_t = topi.full_like(broadcasted, 0) + # fixed = topi.where(topi.less(broadcasted, zero_t), + # topi.add(broadcasted, dim_size_t), + # broadcasted) + # expanded_idx_list.append(fixed) + + # # leftover dimensions => data.shape[k:] + # k = len(idx_list) + # leftover_dims = data.shape[k:] + # # Final output shape is broadcast_shape + leftover_dims + # final_shape = broadcast_shape + leftover_dims + + # # ------------------------------------------------- + # # 3) Build a te.compute that gathers from 'data' + # # ------------------------------------------------- + # def 
_compute(*args): + # # 'args' is a multi-index into final_shape + # # => the first len(broadcast_shape) are the broadcast coords + # # the remaining correspond to leftover_dims + # bdim = len(broadcast_shape) + # leftover_dim = len(leftover_dims) + # assert len(args) == bdim + leftover_dim + + # # advanced_indices for dimension i + # # i.e. i0 = expanded_idx_list[0][b0,b1,...], i1 = expanded_idx_list[1][b0,b1,...], ... + # # leftover coords => the last leftover_dim from 'args' + # # data is presumably shape = [D0, D1, ..., D(k-1), leftover...] + # # So data coordinate is [ i0, i1, ..., i(k-1), leftover0, leftover1, ...] + # data_coords = [] + # for i_ in range(k): + # data_coords.append(expanded_idx_list[i_][*args[:bdim]]) + # # Now append leftover coords + # data_coords.extend(args[bdim:]) + + # return data(*data_coords) + + # # The final te.compute + # out = te.compute( + # final_shape, + # _compute, + # name="multi_adv_index_gather", + # ) + # return out + + # else: + # # flattened = topi.reshape(indices, (-1,)) + # print("ELSE CASE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") + # idxs = indices[0] + # print("type(data)", type(data)) + # print("data", data) + # print("type(idxs)",type(idxs)) + # print("idxs", idxs) + # picked = topi.take(data, idxs, axis=0) + # return picked From 85737ec5f2f6748a2380f77c3450f6031b828eaf Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 00:44:47 -0400 Subject: [PATCH 067/105] dummy concat works --- tests/python/relax/test_from_exported_OLD.py | 176 ++++++++++++++++++ .../python/relax/test_from_exported_concat.py | 86 +++++++++ 2 files changed, 262 insertions(+) create mode 100644 tests/python/relax/test_from_exported_OLD.py create mode 100644 tests/python/relax/test_from_exported_concat.py diff --git a/tests/python/relax/test_from_exported_OLD.py b/tests/python/relax/test_from_exported_OLD.py new file mode 100644 index 000000000000..249e0d6b86a6 --- /dev/null +++ b/tests/python/relax/test_from_exported_OLD.py @@ -0,0 +1,176 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import relax +import tvm.testing +import numpy as np +import torch +from torch import nn +from torch.export import export +from tvm.relax.frontend.torch import from_exported_program +from torch.nn import Softmax, Upsample + + +def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev): + """ + This util ensures that a torch module can successfully be exported to TVM + using torch.export and that the resuling IR program gives the same result + as PyTorch when ran on CUDA. 
+ """ + raw_data_for_tvm = raw_data.copy() # In case the data is modified + torch_data = torch.from_numpy(raw_data) + example_args = (torch_data,) + + with torch.no_grad(): + exported_program = export(torch_module, example_args) + mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True) + + tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch) + + relax_pipeline = relax.get_default_pipeline(tvm.target.Target.from_device(tvm.cuda())) + ex = relax.build(tvm_mod, target=target, relax_pipeline=relax_pipeline) + vm = relax.VirtualMachine(ex, dev) + + gpu_data = tvm.nd.array(raw_data_for_tvm, dev) + gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]] + gpu_out = vm["main"](gpu_data, *gpu_params) + + pytorch_out = torch_module(torch_data) + + if isinstance(pytorch_out, tuple): + for i in range(len(pytorch_out)): + actual = gpu_out[i].numpy() + desired = pytorch_out[i].detach().numpy() + np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) + else: + actual = gpu_out[0].numpy() + desired = pytorch_out.detach().numpy() + np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) + + + +# @tvm.testing.parametrize_targets("cuda") +# def test_full(target, dev): +# class FullModel(nn.Module): +# def __init__(self): +# super().__init__() + +# def forward(self, x): +# return torch.full((2, 3), 3.141592) + +# torch_module = FullModel().eval() + +# raw_data = np.random.rand(3,3).astype("float32") + +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +# Test index.Tensor # TODO aggregate into one big tet + +# @tvm.testing.parametrize_targets("cuda") +# def test_index_tensor0(target, dev): +# class IndexModel0(nn.Module): +# def __init__(self): +# super().__init__() + +# def forward(self, x): +# return x[torch.tensor([0])] + +# torch_module = IndexModel0().eval() + +# raw_data = np.random.rand(3,3).astype("float32") + +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +# @tvm.testing.parametrize_targets("cuda") +# def test_index_tensor1(target, dev): +# class IndexModel1(nn.Module): +# def __init__(self): +# super().__init__() + +# def forward(self, x): +# return x[torch.tensor([[0]])] + +# torch_module = IndexModel1().eval() + +# raw_data = np.random.rand(2,3).astype("float32") + +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +# @tvm.testing.parametrize_targets("cuda") +# def test_index_tensor2(target, dev): +# class IndexTensorModel2(nn.Module): +# def __init__(self): +# super().__init__() + +# def forward(self, x): +# return x[torch.tensor([0,2])] + +# torch_module = IndexTensorModel2().eval() + +# raw_data = np.random.rand(3,4).astype("float32") + +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +# assert 0 + + +# @tvm.testing.parametrize_targets("cuda") +# def test_index_tensor3(target, dev): +# class IndexTensorModel3(nn.Module): +# def __init__(self): +# super().__init__() + +# def forward(self, x): +# return x[[[[0,2],[1,3]]]] + +# torch_module = IndexTensorModel3().eval() +# raw_data = np.random.rand(5,5,5).astype("float32") +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +# @tvm.testing.parametrize_targets("cuda") +# def test_index_tensor4(target, dev): +# class IndexTensorModel4(nn.Module): +# def __init__(self): +# super().__init__() + +# def forward(self, x): +# return x[[[1,4]]] + +# 
torch_module = IndexTensorModel4().eval() +# raw_data = np.random.rand(5,5,5).astype("float32") +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +# @tvm.testing.parametrize_targets("cuda") +# def test_index_tensor5(target, dev): +# class IndexTensorModel5(nn.Module): +# def __init__(self): +# super().__init__() + +# def forward(self, x): +# return x[[[[1,2,4]]]] + +# torch_module = IndexTensorModel5().eval() +# raw_data = np.random.rand(5,5,5).astype("float32") +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_from_exported_concat.py b/tests/python/relax/test_from_exported_concat.py new file mode 100644 index 000000000000..8103bf7f53a0 --- /dev/null +++ b/tests/python/relax/test_from_exported_concat.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import relax +import tvm.testing +import numpy as np +import torch +from torch import nn +from torch.export import export +from tvm.relax.frontend.torch import from_exported_program +from torch.nn import Softmax, Upsample + + +def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev): + """ + This util ensures that a torch module can successfully be exported to TVM + using torch.export and that the resuling IR program gives the same result + as PyTorch when ran on CUDA. 
+ """ + raw_data_for_tvm = raw_data.copy() # In case the data is modified + torch_data = torch.from_numpy(raw_data) + example_args = (torch_data,) + + with torch.no_grad(): + exported_program = export(torch_module, example_args) + mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True) + + tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch) + + relax_pipeline = relax.get_default_pipeline(tvm.target.Target.from_device(tvm.cuda())) + ex = relax.build(tvm_mod, target=target, relax_pipeline=relax_pipeline) + vm = relax.VirtualMachine(ex, dev) + + gpu_data = tvm.nd.array(raw_data_for_tvm, dev) + gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]] + gpu_out = vm["main"](gpu_data, *gpu_params) + + pytorch_out = torch_module(torch_data) + + if isinstance(pytorch_out, tuple): + for i in range(len(pytorch_out)): + actual = gpu_out[i].numpy() + desired = pytorch_out[i].detach().numpy() + np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) + else: + actual = gpu_out[0].numpy() + desired = pytorch_out.detach().numpy() + np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) + +@tvm.testing.parametrize_targets("cuda") +def test_index_tensor2(target, dev): + class ConcatFour(nn.Module): + def __init__(self, dim=1): + super(ConcatFour, self).__init__() + self.dim = dim + self.x2 = torch.randn(2, 3) + self.x3 = torch.randn(2, 3) + self.x4 = torch.randn(2, 3) + + def forward(self, x): + return torch.cat((x ,self.x2, self.x3, self.x4), dim=self.dim) + + torch_module = ConcatFour().eval() + + raw_data = np.random.rand(2,3).astype("float32") + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +if __name__ == "__main__": + tvm.testing.main() From 64f929738af8b1f2acfdebe2f36e3ccbbff0a5b4 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 02:12:10 -0400 Subject: [PATCH 068/105] concat2 doesn't work either --- .../torch/base_fx_graph_translator.py | 8 + .../torch/exported_program_translator.py | 4 +- python/tvm/relax/op/__init__.py | 1 + python/tvm/relax/op/manipulate.py | 7 + .../transform/legalize_ops/manipulate.py | 22 ++ python/tvm/topi/transform.py | 20 ++ src/relax/op/tensor/manipulate.cc | 231 ++++++++++++++++++ src/topi/transform.cc | 4 + .../python/relax/test_from_exported_concat.py | 3 +- 9 files changed, 297 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index ca8286cbf199..340a243e0998 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -987,6 +987,14 @@ def _cat(self, node: fx.Node) -> relax.Var: axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) return self.block_builder.emit(relax.op.concat(args[0], axis=axis)) + def _cat2(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + first_tensor = args[0][0] + other_tensors = args[0][1:] + print("base_fx_graph_translator: type(first_tensor)", type(first_tensor)) + print("base_fx_graph_translator: type(other_tensors)", type(other_tensors)) + return self.block_builder.emit(relax.op.concat2(first_tensor, other_tensors)) + def _chunk(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] chunks = node.args[1] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py 
index a5a76c77e921..c3d17e57b7ad 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -390,10 +390,10 @@ def create_convert_map( "where.self": self._where, # tensor manipulation "argsort.default": self._argsort, - "cat.default": self._cat, + "cat.default": self._cat2, "chunk.default": self._chunk, "clamp.Tensor": self._clamp, - "concat.default": self._cat, + "concat.default": self._cat2, "copy_.default": self._copy_, "cumsum.default": self._cumsum, "cumprod.default": self._cumprod, diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index f81c5448a024..f9a92a9549da 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -89,6 +89,7 @@ collapse_sum_like, collapse_sum_to, concat, + concat2, expand_dims, flatten, flip, diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index da55ef30b2e9..457cf24b8b31 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -70,6 +70,13 @@ def concat(tensors: Union[Expr, List[Expr]], axis: Optional[int] = 0) -> Expr: tensors = RxTuple(tensors) return _ffi_api.concat(tensors, axis) # type: ignore +def concat2(first: Expr, tensors: Union[Expr, List[Expr]]) -> Expr: + if isinstance(tensors, (list, tuple)): + tensors = RxTuple(tensors) + print("manipulate.py: type(first)", type(first)) + print("manipulate.py: type(tensors)", type(tensors)) + return _ffi_api.concat2(first, tensors) # type: ignore + def expand_dims(x: Expr, axis: Union[int, List[int]]) -> Expr: """Insert new axes at the positions given by `axis`. diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 4493135e753d..abe16ef31151 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -72,6 +72,28 @@ def _concat(bb: BlockBuilder, call: Call) -> Expr: ) + +@register_legalize("relax.concat2") +def _concat2(bb: BlockBuilder, call: Call) -> Expr: + assert 0 + first = call.args[0] + t = call.args[1] + n_field = len(t.struct_info.fields) + while isinstance(t, Var): + binding = bb.lookup_binding(t) + if not isinstance(binding, (Tuple, Var)): + break + t = binding + + assert isinstance(t, (Tuple, Var)) + fields = ( + t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)] + ) + return bb.call_te( + topi.concatenate2, first, fields + ) + + @register_legalize("relax.expand_dims") def _expand_dims(bb: BlockBuilder, call: Call) -> Expr: def te_expand_dims(data, axis): diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 40f2c745dc36..7050ca1c5a66 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -402,6 +402,26 @@ def concatenate(a_tuple, axis=0): """ return cpp.concatenate(a_tuple, axis) +def concatenate2(first, a_tuple): + """Join a sequence of arrays along an existing axis. + + Parameters + ---------- + a_tuple : tuple of tvm.te.Tensor + The arrays to concatenate + + axis : int, optional + The axis along which the arrays will be joined. Default is 0. + + Returns + ------- + ret : tvm.te.Tensor + """ + original_list = [first, *a_tuple] + return cpp.concatenate(original_list, 0) + + + def stack(a, axis): """Repeats the whole array multiple times. 
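As a sanity reference for what this `concat2` experiment is aiming at, a minimal NumPy sketch of the expected numerics (`concat2_reference` and the shapes are illustrative only, not part of the patch):

import numpy as np

# concat2(first, rest) is meant to behave like plain concatenation of
# [first, *rest]; topi.concatenate2 above lowers it to exactly that via
# cpp.concatenate([first, *a_tuple], 0).
def concat2_reference(first, rest, axis=0):
    return np.concatenate([first, *rest], axis=axis)

x = np.random.rand(2, 3).astype("float32")
rest = [np.random.rand(2, 3).astype("float32") for _ in range(3)]
assert concat2_reference(x, rest).shape == (8, 3)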
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index dc08f01351cd..27a827b5a66d 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -330,6 +330,237 @@ TVM_REGISTER_OP("relax.concat") .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", Bool(true)); +/* relax.concat2 */ +#include + +Expr concat2(Expr first, Expr tensors) { + assert(0); + static const Op& op = Op::Get("relax.concat2"); + return Call(op, {std::move(first), std::move(tensors)}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.concat2").set_body_typed(concat2); + +Optional> CheckConcatOutputShape2( + const Call& call, const BlockBuilder& ctx, const std::vector>& shape_values) { + assert(0); + + PrimExpr concat_sum = [&]() { + PrimExpr first_concat_dim = shape_values[0][0]; + return first_concat_dim * IntImm(DataType::Int(64), shape_values.size()); + }(); + + Array output_shape = shape_values[0]; + output_shape.Set(0, concat_sum); + return output_shape; +} + +StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { + print("HERE!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"); + TensorStructInfo first_sinfo = GetInputTensorStructInfo(call, 0, ctx); + Array tensor_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[1]); + + if (tensor_sinfo.empty()) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Concat op expects at least one tensor in the input Tuple. However, the " + "given input Tuple is empty."); + } + + const auto* attrs = call->attrs.as(); + int output_ndim = attrs->axis.defined() ? kUnknownNDim : 1; + DataType output_dtype = DataType::Void(); + Optional vdev = NullOpt; + bool shape_unknown = false; + bool is_void_dtype = false; + bool vdevice_unknown = false; + std::vector> shape_values; + shape_values.reserve(tensor_sinfo.size()); + + // First iteration with first_sinfo + if (first_sinfo->dtype.is_void()) { + is_void_dtype = true; + } else if (output_dtype.is_void()) { + output_dtype = first_sinfo->dtype; + } else if (first_sinfo->dtype != output_dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Concat expects all input tensors to have the same dtype. However, the " + "input contains tensors with dtype " + << output_dtype << " and " << first_sinfo->dtype); + } + + // Update the output ndim. + // Todo(relax-team): revisit here for better check on if the input tensor has + // ndim 1 when the input axis is undefined. + if (output_ndim == kUnknownNDim) { + output_ndim = first_sinfo->ndim; + } else if (first_sinfo->ndim != kUnknownNDim && first_sinfo->ndim != output_ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Concat expects all input tensors to have same ndim. However, the " + "input contains tensors with ndim " + << output_ndim << " and " << first_sinfo->ndim); + } + + // Update the virtual device. + if (!vdevice_unknown) { + if (first_sinfo->vdevice.defined()) { + if (!vdev.defined()) { + vdev = first_sinfo->vdevice.value(); + } else if (first_sinfo->vdevice.value()->target.defined()) { + // mismatch + if (first_sinfo->vdevice.value() != vdev) { + vdevice_unknown = true; + } + } + } + } + + // Update the shape values for best effort check. + const auto* shape_expr = first_sinfo->shape.as(); + if (shape_expr != nullptr) { + shape_values.push_back(shape_expr->values); + } else { + shape_unknown = true; + + if (!first_sinfo->shape.defined()) { + } else { + // Keep the shape value for equality check. 
+ ShapeStructInfo shape_sinfo = + Downcast(first_sinfo->shape.value()->struct_info_); + if (shape_sinfo->values.defined()) { + shape_values.push_back(shape_sinfo->values.value()); + } + } + } + + for (TensorStructInfo sinfo : tensor_sinfo) { + // Update the output dtype. + if (sinfo->dtype.is_void()) { + is_void_dtype = true; + } else if (output_dtype.is_void()) { + output_dtype = sinfo->dtype; + } else if (sinfo->dtype != output_dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Concat expects all input tensors to have the same dtype. However, the " + "input contains tensors with dtype " + << output_dtype << " and " << sinfo->dtype); + } + + // Update the output ndim. + // Todo(relax-team): revisit here for better check on if the input tensor has + // ndim 1 when the input axis is undefined. + if (output_ndim == kUnknownNDim) { + output_ndim = sinfo->ndim; + } else if (sinfo->ndim != kUnknownNDim && sinfo->ndim != output_ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Concat expects all input tensors to have same ndim. However, the " + "input contains tensors with ndim " + << output_ndim << " and " << sinfo->ndim); + } + + // Update the virtual device. + if (!vdevice_unknown) { + if (sinfo->vdevice.defined()) { + if (!vdev.defined()) { + vdev = sinfo->vdevice.value(); + } else if (sinfo->vdevice.value()->target.defined()) { + // mismatch + if (sinfo->vdevice.value() != vdev) { + vdevice_unknown = true; + } + } + } + } + + // Update the shape values for best effort check. + const auto* shape_expr = sinfo->shape.as(); + if (shape_expr != nullptr) { + shape_values.push_back(shape_expr->values); + continue; + } + shape_unknown = true; + + if (!sinfo->shape.defined()) { + continue; + } + // Keep the shape value for equality check. + ShapeStructInfo shape_sinfo = Downcast(sinfo->shape.value()->struct_info_); + if (shape_sinfo->values.defined()) { + shape_values.push_back(shape_sinfo->values.value()); + } + } + + if (is_void_dtype) { + output_dtype = DataType::Void(); + } + if (vdevice_unknown) { + vdev = NullOpt; + } + + if (output_ndim == kUnknownNDim) { + return tensor_sinfo.size() == 1 ? tensor_sinfo[0] + : TensorStructInfo(output_dtype, output_ndim, vdev); + } + + // If there is only one input tensor, no action is needed. + if (tensor_sinfo.size() == 1) { + return tensor_sinfo[0]; + } + if (shape_values.empty()) { + if (!vdevice_unknown) { + return TensorStructInfo(output_dtype, output_ndim, vdev); + } + return TensorStructInfo(output_dtype, output_ndim); + } + + // As long as the there is known shape value, we will do the best effort check to ensure safety. 
+ Optional> output_shape = CheckConcatOutputShape2(call, ctx, shape_values); + + if (shape_unknown || !output_shape.defined()) { + if (!vdevice_unknown) { + return TensorStructInfo(output_dtype, output_ndim, vdev); + } + return TensorStructInfo(output_dtype, output_ndim); + } else { + if (!vdevice_unknown) { + return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, vdev); + } + return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype); + } +} + +InferLayoutOutput InferLayoutConcat2(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ICHECK(NoDesiredLayout(call, desired_layouts)); + + const auto* attrs = call->attrs.as(); + ICHECK(attrs != nullptr) << "Invalid Call"; + NLayout nlayout = GetNLayout(var_layout_map, call->args[1]); + ICHECK(nlayout.IsNested()); + ICHECK(nlayout.NestedArray()[0].IsLeaf()); + + int n_tensor = nlayout.NestedArray().size(); + LayoutDecision layout = nlayout.NestedArray()[0].LeafValue(); + Array input_layouts, output_layouts; + for (int i = 0; i < n_tensor; ++i) { + input_layouts.push_back(layout); + } + output_layouts.push_back(layout); + ObjectPtr new_attrs = make_object(*attrs); + new_attrs->axis = Integer(FindAxis(layout->layout, attrs->axis.value_or(0)->value)); + return InferLayoutOutput({NLayout(input_layouts)}, output_layouts, Attrs(new_attrs)); +} + +TVM_REGISTER_OP("relax.concat2") + // .set_attrs_type() + // .set_num_inputs(1) + .add_argument("first", "Tensor", "The first tensor.") + .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.") + .set_attr("FInferStructInfo", InferStructInfoConcat2) + .set_attr("FRelaxInferLayout", InferLayoutConcat2) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); + /* relax.expand_dims */ TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs); diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 7ef63a9b3f56..96be2a9bb56f 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -70,6 +70,10 @@ TVM_REGISTER_GLOBAL("topi.concatenate").set_body([](TVMArgs args, TVMRetValue* r *rv = concatenate(args[0], args[1]); }); +TVM_REGISTER_GLOBAL("topi.concatenate2").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = concatenate(args[0], args[1]); +}); + TVM_REGISTER_GLOBAL("topi.stack").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = stack(args[0], args[1]); }); diff --git a/tests/python/relax/test_from_exported_concat.py b/tests/python/relax/test_from_exported_concat.py index 8103bf7f53a0..f006447f187c 100644 --- a/tests/python/relax/test_from_exported_concat.py +++ b/tests/python/relax/test_from_exported_concat.py @@ -65,7 +65,7 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar @tvm.testing.parametrize_targets("cuda") def test_index_tensor2(target, dev): class ConcatFour(nn.Module): - def __init__(self, dim=1): + def __init__(self, dim=0): super(ConcatFour, self).__init__() self.dim = dim self.x2 = torch.randn(2, 3) @@ -80,6 +80,7 @@ def forward(self, x): raw_data = np.random.rand(2,3).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + assert 0 if __name__ == "__main__": From aa4cfd85329d3a6c002463f573fc929659c7a28f Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 02:22:47 -0400 Subject: [PATCH 069/105] original concat works --- .../tvm/relax/frontend/torch/exported_program_translator.py | 4 ++-- src/relax/op/tensor/manipulate.cc | 2 +- 
tests/python/relax/test_from_exported_concat.py | 1 - 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index c3d17e57b7ad..a5a76c77e921 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -390,10 +390,10 @@ def create_convert_map( "where.self": self._where, # tensor manipulation "argsort.default": self._argsort, - "cat.default": self._cat2, + "cat.default": self._cat, "chunk.default": self._chunk, "clamp.Tensor": self._clamp, - "concat.default": self._cat2, + "concat.default": self._cat, "copy_.default": self._copy_, "cumsum.default": self._cumsum, "cumprod.default": self._cumprod, diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 27a827b5a66d..abc4cfe97f85 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -356,7 +356,7 @@ Optional> CheckConcatOutputShape2( } StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { - print("HERE!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"); + printf("HERE!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"); TensorStructInfo first_sinfo = GetInputTensorStructInfo(call, 0, ctx); Array tensor_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[1]); diff --git a/tests/python/relax/test_from_exported_concat.py b/tests/python/relax/test_from_exported_concat.py index f006447f187c..cc64334a147c 100644 --- a/tests/python/relax/test_from_exported_concat.py +++ b/tests/python/relax/test_from_exported_concat.py @@ -80,7 +80,6 @@ def forward(self, x): raw_data = np.random.rand(2,3).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - assert 0 if __name__ == "__main__": From 1eba32bd2c0bdffe2df09452c5b63b165734500e Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 02:33:27 -0400 Subject: [PATCH 070/105] concat2 works as a perfect copy of concat --- include/tvm/topi/transform.h | 59 ++++++++ .../torch/base_fx_graph_translator.py | 7 +- .../torch/exported_program_translator.py | 4 +- python/tvm/relax/op/manipulate.py | 24 ++- .../transform/legalize_ops/manipulate.py | 7 +- python/tvm/topi/transform.py | 7 +- src/relax/op/tensor/manipulate.cc | 143 ++++++++---------- src/topi/transform.cc | 2 +- 8 files changed, 152 insertions(+), 101 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 762148dcfac3..48b9234f40bf 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -519,6 +519,65 @@ inline Tensor concatenate(const Array& inputs, int axis = 0, std::string name, tag); } +/*! 
+ * \brief Join a sequence of tensors along an existing axis + * + * \param inputs The input tensors + * \param axis The axis along which the tensors will be joined + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the concatenate operation + */ +inline Tensor concatenate2(const Array& inputs, int axis = 0, std::string name = "T_concat", + std::string tag = kInjective) { + int ndim = static_cast(inputs[0]->shape.size()); + ICHECK(-ndim <= axis && axis < ndim) << "concatenate only accepts `axis` in [-ndim, ndim)" + << ", but got axis = " << axis << ", and ndim = " << ndim; + if (axis < 0) { + axis += ndim; + } + ICHECK_LT(axis, inputs[0]->shape.size()) << "axis out of bounds"; + + Array axis_sizes; + for (auto t : inputs) { + axis_sizes.push_back(t->shape[axis]); + } + arith::Analyzer analyzer; + PrimExpr join_size = axis_sizes[0]; + for (size_t i = 1; i < axis_sizes.size(); ++i) { + join_size += axis_sizes[i]; + } + join_size = analyzer.Simplify(join_size); + Array out_shape; + for (size_t i = 0; i < inputs[0]->shape.size(); ++i) { + out_shape.push_back(i == static_cast(axis) ? join_size : inputs[0]->shape[i]); + } + + return compute( + out_shape, + [&](const Array& indices) { + auto ret = inputs[0](indices); + auto ind = indices[axis]; + for (size_t i = 0; i < inputs.size() - 1; ++i) { + ind -= axis_sizes[i]; + + Array idx; + for (size_t i = 0; i < static_cast(axis); ++i) { + idx.push_back(indices[i]); + } + idx.push_back(ind); + for (size_t i = axis + 1; i < indices.size(); ++i) { + idx.push_back(indices[i]); + } + + ret = tvm::if_then_else(ind >= 0, inputs[i + 1](idx), ret); + } + return ret; + }, + name, tag); +} + /*! * \brief Join a sequence of tensors along a new axis. 
* diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 340a243e0998..c2a82859248d 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -989,11 +989,8 @@ def _cat(self, node: fx.Node) -> relax.Var: def _cat2(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) - first_tensor = args[0][0] - other_tensors = args[0][1:] - print("base_fx_graph_translator: type(first_tensor)", type(first_tensor)) - print("base_fx_graph_translator: type(other_tensors)", type(other_tensors)) - return self.block_builder.emit(relax.op.concat2(first_tensor, other_tensors)) + axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) + return self.block_builder.emit(relax.op.concat2(args[0], axis=axis)) def _chunk(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index a5a76c77e921..c3d17e57b7ad 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -390,10 +390,10 @@ def create_convert_map( "where.self": self._where, # tensor manipulation "argsort.default": self._argsort, - "cat.default": self._cat, + "cat.default": self._cat2, "chunk.default": self._chunk, "clamp.Tensor": self._clamp, - "concat.default": self._cat, + "concat.default": self._cat2, "copy_.default": self._copy_, "cumsum.default": self._cumsum, "cumprod.default": self._cumprod, diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 457cf24b8b31..fd61ea9c39eb 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -70,12 +70,28 @@ def concat(tensors: Union[Expr, List[Expr]], axis: Optional[int] = 0) -> Expr: tensors = RxTuple(tensors) return _ffi_api.concat(tensors, axis) # type: ignore -def concat2(first: Expr, tensors: Union[Expr, List[Expr]]) -> Expr: + +def concat2(tensors: Union[Expr, List[Expr]], axis: Optional[int] = 0) -> Expr: + """Concatenate the input tensors along the given axis. + + Parameters + ---------- + tensors : Union[relax.Expr, List[relax.Expr]] + An Expr in Tuple type, containing the tensors to be concatenated, + or a list of Tensors. + + axis : Optional[int] + The axis along which the tensors are concatenated. + If `axis` is `None`, the input tensor is required to be flattened before concatenation. + + Returns + ------- + result: relax.Expr + The concatenated tensor. 
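+
+    Notes
+    -----
+    This is an experimental duplicate of ``concat`` and is intended to give the
+    same result, e.g. ``concat2([x, y], axis=0)`` should match
+    ``concat([x, y], axis=0)``.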
+ """ if isinstance(tensors, (list, tuple)): tensors = RxTuple(tensors) - print("manipulate.py: type(first)", type(first)) - print("manipulate.py: type(tensors)", type(tensors)) - return _ffi_api.concat2(first, tensors) # type: ignore + return _ffi_api.concat2(tensors, axis) # type: ignore def expand_dims(x: Expr, axis: Union[int, List[int]]) -> Expr: diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index abe16ef31151..83a28b0a61ca 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -72,12 +72,9 @@ def _concat(bb: BlockBuilder, call: Call) -> Expr: ) - @register_legalize("relax.concat2") def _concat2(bb: BlockBuilder, call: Call) -> Expr: - assert 0 - first = call.args[0] - t = call.args[1] + t = call.args[0] n_field = len(t.struct_info.fields) while isinstance(t, Var): binding = bb.lookup_binding(t) @@ -90,7 +87,7 @@ def _concat2(bb: BlockBuilder, call: Call) -> Expr: t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)] ) return bb.call_te( - topi.concatenate2, first, fields + topi.concatenate2, fields, None if call.attrs.axis is None else call.attrs.axis.value ) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 7050ca1c5a66..eb61b4f0a018 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -402,7 +402,7 @@ def concatenate(a_tuple, axis=0): """ return cpp.concatenate(a_tuple, axis) -def concatenate2(first, a_tuple): +def concatenate2(a_tuple, axis=0): """Join a sequence of arrays along an existing axis. Parameters @@ -417,10 +417,7 @@ def concatenate2(first, a_tuple): ------- ret : tvm.te.Tensor """ - original_list = [first, *a_tuple] - return cpp.concatenate(original_list, 0) - - + return cpp.concatenate2(a_tuple, axis) def stack(a, axis): diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index abc4cfe97f85..64eeca610e36 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -331,35 +331,74 @@ TVM_REGISTER_OP("relax.concat") .set_attr("FPurity", Bool(true)); /* relax.concat2 */ -#include -Expr concat2(Expr first, Expr tensors) { - assert(0); +Expr concat2(Expr tensors, Optional axis) { + ObjectPtr attrs = make_object(); + attrs->axis = std::move(axis); + static const Op& op = Op::Get("relax.concat2"); - return Call(op, {std::move(first), std::move(tensors)}, Attrs(), {}); + return Call(op, {std::move(tensors)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.concat2").set_body_typed(concat2); - -Optional> CheckConcatOutputShape2( - const Call& call, const BlockBuilder& ctx, const std::vector>& shape_values) { - assert(0); +TVM_REGISTER_GLOBAL("relax.op.concat2").set_body_typed(concat); +Optional> CheckConcatOutputShape2(const Call& call, const BlockBuilder& ctx, + const std::vector>& shape_values, + int axis) { + bool shape_unknown = false; + arith::Analyzer* analyzer = ctx->GetAnalyzer(); PrimExpr concat_sum = [&]() { - PrimExpr first_concat_dim = shape_values[0][0]; - return first_concat_dim * IntImm(DataType::Int(64), shape_values.size()); + // For the specified axis, we compute the sum of shape value over each tensor. 
+ + // Special case, if all concatenated values have the same shape + StructuralEqual structural_equal; + PrimExpr first_concat_dim = shape_values[0][axis]; + bool all_same = std::all_of(shape_values.begin(), shape_values.end(), [&](const auto& a) { + return structural_equal(a[axis], first_concat_dim); + }); + if (all_same) { + return first_concat_dim * IntImm(DataType::Int(64), shape_values.size()); + } + + // General case, add up the dimensions along the specified axis. + PrimExpr concat_sum = IntImm(DataType::Int(64), 0); + for (Array shape_value : shape_values) { + concat_sum += shape_value[axis]; + } + return concat_sum; }(); + // For other axes, we check the equality of all tensors' shape values, to ensure safety. + for (int d = 0; d < static_cast(shape_values[0].size()); ++d) { + if (d == axis) { + continue; + } + for (int i = 1; i < static_cast(shape_values.size()); ++i) { + if (analyzer->CanProve(shape_values[i][d] != shape_values[0][d])) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Concat expects the input tensors to have the same shape on every " + "dimension except the one indicated by the input axis. However, the " + "input contains tensors whose shapes on dimension " + << d << " is " << shape_values[0][d] << " and " << shape_values[i][d]); + } else if (!analyzer->CanProveEqual(shape_values[i][d], shape_values[0][d])) { + shape_unknown = true; + } + } + } + + if (shape_unknown) { + return NullOpt; + } Array output_shape = shape_values[0]; - output_shape.Set(0, concat_sum); + output_shape.Set(axis, concat_sum); return output_shape; } StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { - printf("HERE!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"); - TensorStructInfo first_sinfo = GetInputTensorStructInfo(call, 0, ctx); - Array tensor_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[1]); - + if (call->args.size() != 1) { + ctx->ReportFatal(Diagnostic::Error(call) << "Concat op should have 1 argument"); + } + Array tensor_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[0]); if (tensor_sinfo.empty()) { ctx->ReportFatal(Diagnostic::Error(call) << "Concat op expects at least one tensor in the input Tuple. However, the " @@ -376,62 +415,6 @@ StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { std::vector> shape_values; shape_values.reserve(tensor_sinfo.size()); - // First iteration with first_sinfo - if (first_sinfo->dtype.is_void()) { - is_void_dtype = true; - } else if (output_dtype.is_void()) { - output_dtype = first_sinfo->dtype; - } else if (first_sinfo->dtype != output_dtype) { - ctx->ReportFatal(Diagnostic::Error(call) - << "Concat expects all input tensors to have the same dtype. However, the " - "input contains tensors with dtype " - << output_dtype << " and " << first_sinfo->dtype); - } - - // Update the output ndim. - // Todo(relax-team): revisit here for better check on if the input tensor has - // ndim 1 when the input axis is undefined. - if (output_ndim == kUnknownNDim) { - output_ndim = first_sinfo->ndim; - } else if (first_sinfo->ndim != kUnknownNDim && first_sinfo->ndim != output_ndim) { - ctx->ReportFatal(Diagnostic::Error(call) - << "Concat expects all input tensors to have same ndim. However, the " - "input contains tensors with ndim " - << output_ndim << " and " << first_sinfo->ndim); - } - - // Update the virtual device. 
- if (!vdevice_unknown) { - if (first_sinfo->vdevice.defined()) { - if (!vdev.defined()) { - vdev = first_sinfo->vdevice.value(); - } else if (first_sinfo->vdevice.value()->target.defined()) { - // mismatch - if (first_sinfo->vdevice.value() != vdev) { - vdevice_unknown = true; - } - } - } - } - - // Update the shape values for best effort check. - const auto* shape_expr = first_sinfo->shape.as(); - if (shape_expr != nullptr) { - shape_values.push_back(shape_expr->values); - } else { - shape_unknown = true; - - if (!first_sinfo->shape.defined()) { - } else { - // Keep the shape value for equality check. - ShapeStructInfo shape_sinfo = - Downcast(first_sinfo->shape.value()->struct_info_); - if (shape_sinfo->values.defined()) { - shape_values.push_back(shape_sinfo->values.value()); - } - } - } - for (TensorStructInfo sinfo : tensor_sinfo) { // Update the output dtype. if (sinfo->dtype.is_void()) { @@ -501,6 +484,8 @@ StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { : TensorStructInfo(output_dtype, output_ndim, vdev); } + int axis = + attrs->axis.defined() ? NormalizeAxis(call, ctx, output_ndim, attrs->axis.value()->value) : 0; // If there is only one input tensor, no action is needed. if (tensor_sinfo.size() == 1) { return tensor_sinfo[0]; @@ -513,7 +498,7 @@ StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { } // As long as the there is known shape value, we will do the best effort check to ensure safety. - Optional> output_shape = CheckConcatOutputShape2(call, ctx, shape_values); + Optional> output_shape = CheckConcatOutputShape2(call, ctx, shape_values, axis); if (shape_unknown || !output_shape.defined()) { if (!vdevice_unknown) { @@ -529,13 +514,13 @@ StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { } InferLayoutOutput InferLayoutConcat2(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); ICHECK(attrs != nullptr) << "Invalid Call"; - NLayout nlayout = GetNLayout(var_layout_map, call->args[1]); + NLayout nlayout = GetNLayout(var_layout_map, call->args[0]); ICHECK(nlayout.IsNested()); ICHECK(nlayout.NestedArray()[0].IsLeaf()); @@ -552,15 +537,15 @@ InferLayoutOutput InferLayoutConcat2(const Call& call, } TVM_REGISTER_OP("relax.concat2") - // .set_attrs_type() - // .set_num_inputs(1) - .add_argument("first", "Tensor", "The first tensor.") + .set_attrs_type() + .set_num_inputs(1) .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.") .set_attr("FInferStructInfo", InferStructInfoConcat2) .set_attr("FRelaxInferLayout", InferLayoutConcat2) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", Bool(true)); + /* relax.expand_dims */ TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs); diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 96be2a9bb56f..19192cca6ad1 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -71,7 +71,7 @@ TVM_REGISTER_GLOBAL("topi.concatenate").set_body([](TVMArgs args, TVMRetValue* r }); TVM_REGISTER_GLOBAL("topi.concatenate2").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = concatenate(args[0], args[1]); + *rv = concatenate2(args[0], args[1]); }); TVM_REGISTER_GLOBAL("topi.stack").set_body([](TVMArgs args, TVMRetValue* rv) { From 4e096d1e63dd1ca1a5a7689fb77d19cfeba7340c Mon Sep 17 00:00:00 2001 From: Hugo Latendresse 
Date: Mon, 14 Apr 2025 02:54:47 -0400 Subject: [PATCH 071/105] concat2 works! now need to strip as much as possible, and then convert into indexTensor --- include/tvm/topi/transform.h | 4 ++-- .../frontend/torch/base_fx_graph_translator.py | 2 +- python/tvm/relax/op/manipulate.py | 5 +++-- .../relax/transform/legalize_ops/manipulate.py | 6 ++---- python/tvm/topi/transform.py | 4 ++-- src/relax/op/tensor/manipulate.cc | 16 +++++++++------- src/topi/transform.cc | 2 +- 7 files changed, 20 insertions(+), 19 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 48b9234f40bf..51bddd132834 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -529,8 +529,8 @@ inline Tensor concatenate(const Array& inputs, int axis = 0, std::string * * \return A Tensor whose op member is the concatenate operation */ -inline Tensor concatenate2(const Array& inputs, int axis = 0, std::string name = "T_concat", - std::string tag = kInjective) { +inline Tensor concatenate2(const Tensor& first, const Array& inputs, int axis = 0, + std::string name = "T_concat", std::string tag = kInjective) { int ndim = static_cast(inputs[0]->shape.size()); ICHECK(-ndim <= axis && axis < ndim) << "concatenate only accepts `axis` in [-ndim, ndim)" << ", but got axis = " << axis << ", and ndim = " << ndim; diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index c2a82859248d..e50433a9b6ac 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -990,7 +990,7 @@ def _cat(self, node: fx.Node) -> relax.Var: def _cat2(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) - return self.block_builder.emit(relax.op.concat2(args[0], axis=axis)) + return self.block_builder.emit(relax.op.concat2(args[0][0], args[0], axis=axis)) def _chunk(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index fd61ea9c39eb..9a3d5e8e7760 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -71,7 +71,7 @@ def concat(tensors: Union[Expr, List[Expr]], axis: Optional[int] = 0) -> Expr: return _ffi_api.concat(tensors, axis) # type: ignore -def concat2(tensors: Union[Expr, List[Expr]], axis: Optional[int] = 0) -> Expr: +def concat2(first:Expr, tensors: Union[Expr, List[Expr]], axis: Optional[int] = 0) -> Expr: """Concatenate the input tensors along the given axis. Parameters @@ -91,7 +91,8 @@ def concat2(tensors: Union[Expr, List[Expr]], axis: Optional[int] = 0) -> Expr: """ if isinstance(tensors, (list, tuple)): tensors = RxTuple(tensors) - return _ffi_api.concat2(tensors, axis) # type: ignore + # return _ffi_api.concat2(tensors, axis) # TODO this works for some reason! 
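+    # (Most likely because, before this patch, relax.op.concat2 was still registered with
+    #  set_body_typed(concat) on the C++ side, so the (tensors, axis) call matched that binding.)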
+ return _ffi_api.concat2(first, tensors, axis) # type: ignore def expand_dims(x: Expr, axis: Union[int, List[int]]) -> Expr: diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 83a28b0a61ca..6dc8e8283b4a 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -74,7 +74,7 @@ def _concat(bb: BlockBuilder, call: Call) -> Expr: @register_legalize("relax.concat2") def _concat2(bb: BlockBuilder, call: Call) -> Expr: - t = call.args[0] + t = call.args[1] n_field = len(t.struct_info.fields) while isinstance(t, Var): binding = bb.lookup_binding(t) @@ -87,7 +87,7 @@ def _concat2(bb: BlockBuilder, call: Call) -> Expr: t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)] ) return bb.call_te( - topi.concatenate2, fields, None if call.attrs.axis is None else call.attrs.axis.value + topi.concatenate2, call.args[0], fields, None if call.attrs.axis is None else call.attrs.axis.value ) @@ -206,8 +206,6 @@ def _index_tensor(bb: BlockBuilder, call: Call) -> Expr: ) - - @register_legalize("relax.scatter_elements") def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr: return bb.call_te( diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index eb61b4f0a018..f49556599611 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -402,7 +402,7 @@ def concatenate(a_tuple, axis=0): """ return cpp.concatenate(a_tuple, axis) -def concatenate2(a_tuple, axis=0): +def concatenate2(first, a_tuple, axis=0): """Join a sequence of arrays along an existing axis. Parameters @@ -417,7 +417,7 @@ def concatenate2(a_tuple, axis=0): ------- ret : tvm.te.Tensor """ - return cpp.concatenate2(a_tuple, axis) + return cpp.concatenate2(first, a_tuple, axis) def stack(a, axis): diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 64eeca610e36..5b08c41426dc 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -332,15 +332,16 @@ TVM_REGISTER_OP("relax.concat") /* relax.concat2 */ -Expr concat2(Expr tensors, Optional axis) { +Expr concat2(Expr first, Expr tensors, Optional axis) { ObjectPtr attrs = make_object(); attrs->axis = std::move(axis); static const Op& op = Op::Get("relax.concat2"); - return Call(op, {std::move(tensors)}, Attrs(attrs), {}); + return Call(op, {std::move(first), std::move(tensors)}, Attrs(attrs), {}); + } -TVM_REGISTER_GLOBAL("relax.op.concat2").set_body_typed(concat); +TVM_REGISTER_GLOBAL("relax.op.concat2").set_body_typed(concat2); Optional> CheckConcatOutputShape2(const Call& call, const BlockBuilder& ctx, const std::vector>& shape_values, @@ -352,7 +353,7 @@ Optional> CheckConcatOutputShape2(const Call& call, const BlockB // Special case, if all concatenated values have the same shape StructuralEqual structural_equal; - PrimExpr first_concat_dim = shape_values[0][axis]; + PrimExpr first_concat_dim = shape_values[1][axis]; bool all_same = std::all_of(shape_values.begin(), shape_values.end(), [&](const auto& a) { return structural_equal(a[axis], first_concat_dim); }); @@ -395,10 +396,10 @@ Optional> CheckConcatOutputShape2(const Call& call, const BlockB } StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { - if (call->args.size() != 1) { + if (call->args.size() != 2) { ctx->ReportFatal(Diagnostic::Error(call) << "Concat op should have 1 argument"); } - Array tensor_sinfo = 
GetTensorStructInfoFromTuple(call, ctx, call->args[0]); + Array tensor_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[1]); if (tensor_sinfo.empty()) { ctx->ReportFatal(Diagnostic::Error(call) << "Concat op expects at least one tensor in the input Tuple. However, the " @@ -538,7 +539,8 @@ InferLayoutOutput InferLayoutConcat2(const Call& call, TVM_REGISTER_OP("relax.concat2") .set_attrs_type() - .set_num_inputs(1) + .set_num_inputs(2) + .add_argument("first", "Tensor", "The first tensor") .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.") .set_attr("FInferStructInfo", InferStructInfoConcat2) .set_attr("FRelaxInferLayout", InferLayoutConcat2) diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 19192cca6ad1..426d8740ef78 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -71,7 +71,7 @@ TVM_REGISTER_GLOBAL("topi.concatenate").set_body([](TVMArgs args, TVMRetValue* r }); TVM_REGISTER_GLOBAL("topi.concatenate2").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = concatenate2(args[0], args[1]); + *rv = concatenate2(args[0], args[1], args[2]); }); TVM_REGISTER_GLOBAL("topi.stack").set_body([](TVMArgs args, TVMRetValue* rv) { From 8adbc6180da6a342daea17fa08a6e550ec58f99a Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 02:56:07 -0400 Subject: [PATCH 072/105] concat2 still passes --- src/relax/op/tensor/manipulate.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 5b08c41426dc..cbf733586ad6 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -544,7 +544,7 @@ TVM_REGISTER_OP("relax.concat2") .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.") .set_attr("FInferStructInfo", InferStructInfoConcat2) .set_attr("FRelaxInferLayout", InferLayoutConcat2) - .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + // .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", Bool(true)); From 1cfa0247a607bc3e988d98e6d644ff9aa20e238c Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 02:56:25 -0400 Subject: [PATCH 073/105] concat2 still passes --- src/relax/op/tensor/manipulate.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index cbf733586ad6..ffc73985eeb9 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -543,7 +543,7 @@ TVM_REGISTER_OP("relax.concat2") .add_argument("first", "Tensor", "The first tensor") .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.") .set_attr("FInferStructInfo", InferStructInfoConcat2) - .set_attr("FRelaxInferLayout", InferLayoutConcat2) + // .set_attr("FRelaxInferLayout", InferLayoutConcat2) // .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", Bool(true)); From f0d533bb9086b224d42b955703a6956b9198f5d4 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 03:02:48 -0400 Subject: [PATCH 074/105] still works whne grabing first tensro --- src/relax/op/tensor/manipulate.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index ffc73985eeb9..54b18ada2806 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -338,14 +338,13 @@ Expr concat2(Expr 
first, Expr tensors, Optional axis) { static const Op& op = Op::Get("relax.concat2"); return Call(op, {std::move(first), std::move(tensors)}, Attrs(attrs), {}); - } TVM_REGISTER_GLOBAL("relax.op.concat2").set_body_typed(concat2); Optional> CheckConcatOutputShape2(const Call& call, const BlockBuilder& ctx, - const std::vector>& shape_values, - int axis) { + const std::vector>& shape_values, + int axis) { bool shape_unknown = false; arith::Analyzer* analyzer = ctx->GetAnalyzer(); PrimExpr concat_sum = [&]() { @@ -399,6 +398,8 @@ StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { ctx->ReportFatal(Diagnostic::Error(call) << "Concat op should have 1 argument"); } + TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); + Array tensor_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[1]); if (tensor_sinfo.empty()) { ctx->ReportFatal(Diagnostic::Error(call) @@ -515,8 +516,8 @@ StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { } InferLayoutOutput InferLayoutConcat2(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); @@ -542,12 +543,11 @@ TVM_REGISTER_OP("relax.concat2") .set_num_inputs(2) .add_argument("first", "Tensor", "The first tensor") .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.") - .set_attr("FInferStructInfo", InferStructInfoConcat2) + .set_attr("FInferStructInfo", InferStructInfoConcat2) // TODO necessary // .set_attr("FRelaxInferLayout", InferLayoutConcat2) // .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", Bool(true)); - /* relax.expand_dims */ TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs); From d22b59891764ad954ded5dffea65cefc05bf981b Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 03:05:21 -0400 Subject: [PATCH 075/105] still works whne grabing first tensro --- src/relax/op/tensor/manipulate.cc | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 54b18ada2806..2afc3fff5c39 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -409,7 +409,9 @@ StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); int output_ndim = attrs->axis.defined() ? kUnknownNDim : 1; - DataType output_dtype = DataType::Void(); + DataType output_dtype = data_sinfo->dtype; + + Optional vdev = NullOpt; bool shape_unknown = false; bool is_void_dtype = false; @@ -421,14 +423,7 @@ StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { // Update the output dtype. if (sinfo->dtype.is_void()) { is_void_dtype = true; - } else if (output_dtype.is_void()) { - output_dtype = sinfo->dtype; - } else if (sinfo->dtype != output_dtype) { - ctx->ReportFatal(Diagnostic::Error(call) - << "Concat expects all input tensors to have the same dtype. However, the " - "input contains tensors with dtype " - << output_dtype << " and " << sinfo->dtype); - } + } // Update the output ndim. 
// Todo(relax-team): revisit here for better check on if the input tensor has From 636ab70db59328ae76f1e6c12bcb877c13d139ed Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 03:10:34 -0400 Subject: [PATCH 076/105] still works whne grabing first tensro --- src/relax/op/tensor/manipulate.cc | 52 +++++-------------------------- 1 file changed, 8 insertions(+), 44 deletions(-) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 2afc3fff5c39..76be6d65ef62 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -409,13 +409,17 @@ StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); int output_ndim = attrs->axis.defined() ? kUnknownNDim : 1; - DataType output_dtype = data_sinfo->dtype; - + DataType output_dtype = data_sinfo->dtype; + bool vdevice_unknown = false; Optional vdev = NullOpt; + if (data_sinfo->vdevice.defined()) { + vdev = data_sinfo->vdevice.value(); + vdevice_unknown = true; + } + bool shape_unknown = false; bool is_void_dtype = false; - bool vdevice_unknown = false; std::vector> shape_values; shape_values.reserve(tensor_sinfo.size()); @@ -423,7 +427,7 @@ StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { // Update the output dtype. if (sinfo->dtype.is_void()) { is_void_dtype = true; - } + } // Update the output ndim. // Todo(relax-team): revisit here for better check on if the input tensor has @@ -437,20 +441,6 @@ StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { << output_ndim << " and " << sinfo->ndim); } - // Update the virtual device. - if (!vdevice_unknown) { - if (sinfo->vdevice.defined()) { - if (!vdev.defined()) { - vdev = sinfo->vdevice.value(); - } else if (sinfo->vdevice.value()->target.defined()) { - // mismatch - if (sinfo->vdevice.value() != vdev) { - vdevice_unknown = true; - } - } - } - } - // Update the shape values for best effort check. const auto* shape_expr = sinfo->shape.as(); if (shape_expr != nullptr) { @@ -472,9 +462,6 @@ StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { if (is_void_dtype) { output_dtype = DataType::Void(); } - if (vdevice_unknown) { - vdev = NullOpt; - } if (output_ndim == kUnknownNDim) { return tensor_sinfo.size() == 1 ? 
tensor_sinfo[0] @@ -510,29 +497,6 @@ StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { } } -InferLayoutOutput InferLayoutConcat2(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); - - const auto* attrs = call->attrs.as(); - ICHECK(attrs != nullptr) << "Invalid Call"; - NLayout nlayout = GetNLayout(var_layout_map, call->args[0]); - ICHECK(nlayout.IsNested()); - ICHECK(nlayout.NestedArray()[0].IsLeaf()); - - int n_tensor = nlayout.NestedArray().size(); - LayoutDecision layout = nlayout.NestedArray()[0].LeafValue(); - Array input_layouts, output_layouts; - for (int i = 0; i < n_tensor; ++i) { - input_layouts.push_back(layout); - } - output_layouts.push_back(layout); - ObjectPtr new_attrs = make_object(*attrs); - new_attrs->axis = Integer(FindAxis(layout->layout, attrs->axis.value_or(0)->value)); - return InferLayoutOutput({NLayout(input_layouts)}, output_layouts, Attrs(new_attrs)); -} - TVM_REGISTER_OP("relax.concat2") .set_attrs_type() .set_num_inputs(2) From cd0f013fcce8a7368e9d401ec6ba15434eb8af5a Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 03:17:54 -0400 Subject: [PATCH 077/105] adding back struct info makes test pass --- src/relax/op/tensor/manipulate.cc | 144 ++++++++++++++++-------------- 1 file changed, 75 insertions(+), 69 deletions(-) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 76be6d65ef62..52838a13d9b3 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -420,81 +420,87 @@ StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { bool shape_unknown = false; bool is_void_dtype = false; - std::vector> shape_values; - shape_values.reserve(tensor_sinfo.size()); - - for (TensorStructInfo sinfo : tensor_sinfo) { - // Update the output dtype. - if (sinfo->dtype.is_void()) { - is_void_dtype = true; - } - - // Update the output ndim. - // Todo(relax-team): revisit here for better check on if the input tensor has - // ndim 1 when the input axis is undefined. - if (output_ndim == kUnknownNDim) { - output_ndim = sinfo->ndim; - } else if (sinfo->ndim != kUnknownNDim && sinfo->ndim != output_ndim) { - ctx->ReportFatal(Diagnostic::Error(call) - << "Concat expects all input tensors to have same ndim. However, the " - "input contains tensors with ndim " - << output_ndim << " and " << sinfo->ndim); - } - - // Update the shape values for best effort check. - const auto* shape_expr = sinfo->shape.as(); - if (shape_expr != nullptr) { - shape_values.push_back(shape_expr->values); - continue; - } - shape_unknown = true; - - if (!sinfo->shape.defined()) { - continue; - } - // Keep the shape value for equality check. - ShapeStructInfo shape_sinfo = Downcast(sinfo->shape.value()->struct_info_); - if (shape_sinfo->values.defined()) { - shape_values.push_back(shape_sinfo->values.value()); - } + if (data_sinfo->dtype.is_void()) { + is_void_dtype = true; } - if (is_void_dtype) { - output_dtype = DataType::Void(); - } + std::vector> shape_values; + shape_values.reserve(tensor_sinfo.size()); - if (output_ndim == kUnknownNDim) { - return tensor_sinfo.size() == 1 ? tensor_sinfo[0] - : TensorStructInfo(output_dtype, output_ndim, vdev); - } - int axis = - attrs->axis.defined() ? NormalizeAxis(call, ctx, output_ndim, attrs->axis.value()->value) : 0; - // If there is only one input tensor, no action is needed. 
- if (tensor_sinfo.size() == 1) { - return tensor_sinfo[0]; - } - if (shape_values.empty()) { - if (!vdevice_unknown) { - return TensorStructInfo(output_dtype, output_ndim, vdev); - } - return TensorStructInfo(output_dtype, output_ndim); - } + // TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); + // DataType output_dtype = data_sinfo->dtype; - // As long as the there is known shape value, we will do the best effort check to ensure safety. - Optional> output_shape = CheckConcatOutputShape2(call, ctx, shape_values, axis); + return TensorStructInfo(output_dtype, kUnknownNDim); - if (shape_unknown || !output_shape.defined()) { - if (!vdevice_unknown) { - return TensorStructInfo(output_dtype, output_ndim, vdev); - } - return TensorStructInfo(output_dtype, output_ndim); - } else { - if (!vdevice_unknown) { - return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, vdev); - } - return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype); - } + // for (TensorStructInfo sinfo : tensor_sinfo) { + + // // Update the output ndim. + // // Todo(relax-team): revisit here for better check on if the input tensor has + // // ndim 1 when the input axis is undefined. + // if (output_ndim == kUnknownNDim) { + // output_ndim = sinfo->ndim; + // } else if (sinfo->ndim != kUnknownNDim && sinfo->ndim != output_ndim) { + // ctx->ReportFatal(Diagnostic::Error(call) + // << "Concat expects all input tensors to have same ndim. However, the " + // "input contains tensors with ndim " + // << output_ndim << " and " << sinfo->ndim); + // } + + // // Update the shape values for best effort check. + // const auto* shape_expr = sinfo->shape.as(); + // if (shape_expr != nullptr) { + // shape_values.push_back(shape_expr->values); + // continue; + // } + // shape_unknown = true; + + // if (!sinfo->shape.defined()) { + // continue; + // } + // // Keep the shape value for equality check. + // ShapeStructInfo shape_sinfo = Downcast(sinfo->shape.value()->struct_info_); + // if (shape_sinfo->values.defined()) { + // shape_values.push_back(shape_sinfo->values.value()); + // } + // } + + // if (is_void_dtype) { + // output_dtype = DataType::Void(); + // } + + // if (output_ndim == kUnknownNDim) { + // return tensor_sinfo.size() == 1 ? tensor_sinfo[0] + // : TensorStructInfo(output_dtype, output_ndim, vdev); + // } + + // int axis = + // attrs->axis.defined() ? NormalizeAxis(call, ctx, output_ndim, attrs->axis.value()->value) : 0; + // // If there is only one input tensor, no action is needed. + // if (tensor_sinfo.size() == 1) { + // return tensor_sinfo[0]; + // } + // if (shape_values.empty()) { + // if (!vdevice_unknown) { + // return TensorStructInfo(output_dtype, output_ndim, vdev); + // } + // return TensorStructInfo(output_dtype, output_ndim); + // } + + // // As long as the there is known shape value, we will do the best effort check to ensure safety. 
+ // Optional> output_shape = CheckConcatOutputShape2(call, ctx, shape_values, axis); + + // if (shape_unknown || !output_shape.defined()) { + // if (!vdevice_unknown) { + // return TensorStructInfo(output_dtype, output_ndim, vdev); + // } + // return TensorStructInfo(output_dtype, output_ndim); + // } else { + // if (!vdevice_unknown) { + // return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, vdev); + // } + // return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype); + // } } TVM_REGISTER_OP("relax.concat2") From a06db3d485d23776d6a797d6f1c0c5b76dd6d1eb Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 03:31:50 -0400 Subject: [PATCH 078/105] concat2 passes with very simple manipulate.ccgit status! --- src/relax/op/tensor/manipulate.cc | 33 +++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 52838a13d9b3..ba2cf272f71f 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -424,16 +424,34 @@ StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { is_void_dtype = true; } - std::vector> shape_values; - shape_values.reserve(tensor_sinfo.size()); + TensorStructInfo indices_sinfo = data_sinfo; + Optional> data_shape_value; + if (data_sinfo->shape.defined()) { + data_shape_value = GetStructInfoAs(data_sinfo->shape.value())->values; + } + Optional> indices_shape_value; + if (indices_sinfo->shape.defined()) { + indices_shape_value = + GetStructInfoAs(indices_sinfo->shape.value())->values; + } - // TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); - // DataType output_dtype = data_sinfo->dtype; + if (indices_sinfo->shape.defined()) { + return TensorStructInfo(indices_sinfo->shape.value(), output_dtype, indices_sinfo->vdevice); + } else { + return TensorStructInfo(output_dtype, indices_sinfo->ndim, indices_sinfo->vdevice); + } - return TensorStructInfo(output_dtype, kUnknownNDim); + // std::vector> shape_values; + // shape_values.reserve(tensor_sinfo.size()); + + + // // TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); + // // DataType output_dtype = data_sinfo->dtype; + + // // return TensorStructInfo(output_dtype, kUnknownNDim); - // for (TensorStructInfo sinfo : tensor_sinfo) { + // for (TensorStructInfo sinfo : tensor_sinfo) { // TODO need this for-loop! // // Update the output ndim. // // Todo(relax-team): revisit here for better check on if the input tensor has @@ -487,6 +505,7 @@ StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { // return TensorStructInfo(output_dtype, output_ndim); // } + // // TODO why do I need to output_shape here?? index_tensor did not require it! // // As long as the there is known shape value, we will do the best effort check to ensure safety. 
// Optional> output_shape = CheckConcatOutputShape2(call, ctx, shape_values, axis); @@ -509,8 +528,6 @@ TVM_REGISTER_OP("relax.concat2") .add_argument("first", "Tensor", "The first tensor") .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.") .set_attr("FInferStructInfo", InferStructInfoConcat2) // TODO necessary - // .set_attr("FRelaxInferLayout", InferLayoutConcat2) - // .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", Bool(true)); /* relax.expand_dims */ From aa08975aa02abab717faf6d0a9c7cfb19933fe8e Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 03:36:24 -0400 Subject: [PATCH 079/105] I pass old test2! --- .../torch/base_fx_graph_translator.py | 5 ++-- python/tvm/topi/transform.py | 3 ++- tests/python/relax/test_from_exported_OLD.py | 23 +++++++++---------- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index e50433a9b6ac..f1529e04f281 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1107,8 +1107,9 @@ def _index_tensor(self, node: fx.Node) -> relax.Var: # index = self.env[node.args[1]] # TODO # return self.block_builder.emit(relax.op.index_tensor(args[0], indices)) # TODO revert! removed to test collapse sum like - return self.block_builder.emit(relax.op.index_tensor(args[0], indices)) - # return self.block_builder.emit(relax.op.index_tensor(args[0], indices)) # TODO switch the above to this + # return self.block_builder.emit(relax.op.index_tensor(args[0], indices)) + return self.block_builder.emit(relax.op.concat2(args[0], indices)) + def _permute(self, node: fx.Node) -> relax.Var: import torch # type: ignore diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index f49556599611..8f773da47bf5 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -417,7 +417,8 @@ def concatenate2(first, a_tuple, axis=0): ------- ret : tvm.te.Tensor """ - return cpp.concatenate2(first, a_tuple, axis) + # return cpp.concatenate2(first, a_tuple, axis) + return topi.adv_index(first, a_tuple) def stack(a, axis): diff --git a/tests/python/relax/test_from_exported_OLD.py b/tests/python/relax/test_from_exported_OLD.py index 249e0d6b86a6..3d8016e38bae 100644 --- a/tests/python/relax/test_from_exported_OLD.py +++ b/tests/python/relax/test_from_exported_OLD.py @@ -114,21 +114,20 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar # assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) -# @tvm.testing.parametrize_targets("cuda") -# def test_index_tensor2(target, dev): -# class IndexTensorModel2(nn.Module): -# def __init__(self): -# super().__init__() - -# def forward(self, x): -# return x[torch.tensor([0,2])] +@tvm.testing.parametrize_targets("cuda") +def test_index_tensor2(target, dev): + class IndexTensorModel2(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[torch.tensor([0,2])] -# torch_module = IndexTensorModel2().eval() + torch_module = IndexTensorModel2().eval() -# raw_data = np.random.rand(3,4).astype("float32") + raw_data = np.random.rand(3,4).astype("float32") -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) -# assert 0 + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, 
target, dev) # @tvm.testing.parametrize_targets("cuda") From f23b7864ac45685082f214962c5286a96a1725f9 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 03:37:32 -0400 Subject: [PATCH 080/105] 7 tests pass! --- tests/python/relax/test_from_exported_OLD.py | 122 +++++++++---------- 1 file changed, 61 insertions(+), 61 deletions(-) diff --git a/tests/python/relax/test_from_exported_OLD.py b/tests/python/relax/test_from_exported_OLD.py index 3d8016e38bae..df67072f7851 100644 --- a/tests/python/relax/test_from_exported_OLD.py +++ b/tests/python/relax/test_from_exported_OLD.py @@ -64,54 +64,54 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar -# @tvm.testing.parametrize_targets("cuda") -# def test_full(target, dev): -# class FullModel(nn.Module): -# def __init__(self): -# super().__init__() - -# def forward(self, x): -# return torch.full((2, 3), 3.141592) +@tvm.testing.parametrize_targets("cuda") +def test_full(target, dev): + class FullModel(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.full((2, 3), 3.141592) -# torch_module = FullModel().eval() + torch_module = FullModel().eval() -# raw_data = np.random.rand(3,3).astype("float32") + raw_data = np.random.rand(3,3).astype("float32") -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) # Test index.Tensor # TODO aggregate into one big tet -# @tvm.testing.parametrize_targets("cuda") -# def test_index_tensor0(target, dev): -# class IndexModel0(nn.Module): -# def __init__(self): -# super().__init__() +@tvm.testing.parametrize_targets("cuda") +def test_index_tensor0(target, dev): + class IndexModel0(nn.Module): + def __init__(self): + super().__init__() -# def forward(self, x): -# return x[torch.tensor([0])] + def forward(self, x): + return x[torch.tensor([0])] -# torch_module = IndexModel0().eval() + torch_module = IndexModel0().eval() -# raw_data = np.random.rand(3,3).astype("float32") + raw_data = np.random.rand(3,3).astype("float32") -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) -# @tvm.testing.parametrize_targets("cuda") -# def test_index_tensor1(target, dev): -# class IndexModel1(nn.Module): -# def __init__(self): -# super().__init__() +@tvm.testing.parametrize_targets("cuda") +def test_index_tensor1(target, dev): + class IndexModel1(nn.Module): + def __init__(self): + super().__init__() -# def forward(self, x): -# return x[torch.tensor([[0]])] + def forward(self, x): + return x[torch.tensor([[0]])] -# torch_module = IndexModel1().eval() + torch_module = IndexModel1().eval() -# raw_data = np.random.rand(2,3).astype("float32") + raw_data = np.random.rand(2,3).astype("float32") -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) @tvm.testing.parametrize_targets("cuda") @@ -130,46 +130,46 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) -# @tvm.testing.parametrize_targets("cuda") -# def test_index_tensor3(target, dev): -# class IndexTensorModel3(nn.Module): -# def __init__(self): -# super().__init__() +@tvm.testing.parametrize_targets("cuda") +def test_index_tensor3(target, dev): + class 
IndexTensorModel3(nn.Module): + def __init__(self): + super().__init__() -# def forward(self, x): -# return x[[[[0,2],[1,3]]]] + def forward(self, x): + return x[[[[0,2],[1,3]]]] -# torch_module = IndexTensorModel3().eval() -# raw_data = np.random.rand(5,5,5).astype("float32") -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + torch_module = IndexTensorModel3().eval() + raw_data = np.random.rand(5,5,5).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) -# @tvm.testing.parametrize_targets("cuda") -# def test_index_tensor4(target, dev): -# class IndexTensorModel4(nn.Module): -# def __init__(self): -# super().__init__() +@tvm.testing.parametrize_targets("cuda") +def test_index_tensor4(target, dev): + class IndexTensorModel4(nn.Module): + def __init__(self): + super().__init__() -# def forward(self, x): -# return x[[[1,4]]] + def forward(self, x): + return x[[[1,4]]] -# torch_module = IndexTensorModel4().eval() -# raw_data = np.random.rand(5,5,5).astype("float32") -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + torch_module = IndexTensorModel4().eval() + raw_data = np.random.rand(5,5,5).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) -# @tvm.testing.parametrize_targets("cuda") -# def test_index_tensor5(target, dev): -# class IndexTensorModel5(nn.Module): -# def __init__(self): -# super().__init__() +@tvm.testing.parametrize_targets("cuda") +def test_index_tensor5(target, dev): + class IndexTensorModel5(nn.Module): + def __init__(self): + super().__init__() -# def forward(self, x): -# return x[[[[1,2,4]]]] + def forward(self, x): + return x[[[[1,2,4]]]] -# torch_module = IndexTensorModel5().eval() -# raw_data = np.random.rand(5,5,5).astype("float32") -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + torch_module = IndexTensorModel5().eval() + raw_data = np.random.rand(5,5,5).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) if __name__ == "__main__": tvm.testing.main() From 583f1e59e6372820ea29b46c9376c8852312ccc3 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 03:40:31 -0400 Subject: [PATCH 081/105] all tests passgit statusgit status --- tests/python/relax/test_from_exported_OLD.py | 44 +++++++++++++++++++ .../relax/test_from_exported_to_cuda_NEW.py | 15 ------- 2 files changed, 44 insertions(+), 15 deletions(-) diff --git a/tests/python/relax/test_from_exported_OLD.py b/tests/python/relax/test_from_exported_OLD.py index df67072f7851..289281d7648a 100644 --- a/tests/python/relax/test_from_exported_OLD.py +++ b/tests/python/relax/test_from_exported_OLD.py @@ -171,5 +171,49 @@ def forward(self, x): raw_data = np.random.rand(5,5,5).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + +@tvm.testing.parametrize_targets("cuda") +def test_index_tensor6(target, dev): + class IndexTensorModel6(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[[[0,1],[0,1]]] + + torch_module = IndexTensorModel6().eval() + raw_data = np.random.rand(5,5,5,5).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_index_tensor7(target, dev): + class IndexTensorModel7(nn.Module): + def __init__(self): + 
super().__init__() + + def forward(self, x): + return x[[[0,1,2,3], [1,2,3,4], [2,3,4,0]]] # both args[0] and indices are expr.Var + + torch_module = IndexTensorModel7().eval() + raw_data = np.random.rand(5,5,5,5).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_index_tensor8(target, dev): + class IndexTensorModel8(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[[[[0,1],[2,3]],[[2,3],[3,4]],[[2,4],[1,2]],[[0,4],[0,3]]]] + + torch_module = IndexTensorModel8().eval() + raw_data = np.random.rand(5,5,5,5).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_from_exported_to_cuda_NEW.py b/tests/python/relax/test_from_exported_to_cuda_NEW.py index dd449c344012..8744af9a82f7 100644 --- a/tests/python/relax/test_from_exported_to_cuda_NEW.py +++ b/tests/python/relax/test_from_exported_to_cuda_NEW.py @@ -63,20 +63,5 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) - -@tvm.testing.parametrize_targets("cuda") -def test_index_tensor6(target, dev): - class IndexTensorModel6(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x[[[0,1],[0,1]]] - - torch_module = IndexTensorModel6().eval() - raw_data = np.random.rand(5,5,5,5).astype("float32") - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - - if __name__ == "__main__": tvm.testing.main() From 4ce35b9edf61374a9ebcf71a45627ca7ee998d04 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 03:46:37 -0400 Subject: [PATCH 082/105] all tests still psas --- include/tvm/relax/attrs/manipulate.h | 9 -- include/tvm/topi/transform.h | 59 ------- .../transform/legalize_ops/manipulate.py | 2 +- python/tvm/topi/broadcast.py | 3 +- python/tvm/topi/transform.py | 148 +----------------- src/relax/op/tensor/manipulate.cc | 80 ---------- src/topi/transform.cc | 4 - 7 files changed, 3 insertions(+), 302 deletions(-) diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h index 258aa20d4703..e6c16d233a6b 100644 --- a/include/tvm/relax/attrs/manipulate.h +++ b/include/tvm/relax/attrs/manipulate.h @@ -169,15 +169,6 @@ struct GatherNDAttrs : public tvm::AttrsNode { } }; // struct GatherNDAttrs -// TODO maybe we don't need this? -/*! \brief Attributes used in index_tensor operators */ -// struct IndexTensorAttrs : public tvm::AttrsNode { -// Array indices; // TODO will need to extend this, since could be an array of arrays? -// TVM_DECLARE_ATTRS(IndexTensorAttrs, "relax.attrs.IndexTensorAttrs") { -// TVM_ATTR_FIELD(indices).describe("The indices to select."); -// } -// }; // struct IndexTensorAttrs - /*! \brief Attributes used in scatter_elements operators */ struct ScatterElementsAttrs : public tvm::AttrsNode { Integer axis; diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 51bddd132834..762148dcfac3 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -519,65 +519,6 @@ inline Tensor concatenate(const Array& inputs, int axis = 0, std::string name, tag); } -/*! 
- * \brief Join a sequence of tensors along an existing axis - * - * \param inputs The input tensors - * \param axis The axis along which the tensors will be joined - * \param name The name of the operation - * \param tag The tag to mark the operation - * - * \return A Tensor whose op member is the concatenate operation - */ -inline Tensor concatenate2(const Tensor& first, const Array& inputs, int axis = 0, - std::string name = "T_concat", std::string tag = kInjective) { - int ndim = static_cast(inputs[0]->shape.size()); - ICHECK(-ndim <= axis && axis < ndim) << "concatenate only accepts `axis` in [-ndim, ndim)" - << ", but got axis = " << axis << ", and ndim = " << ndim; - if (axis < 0) { - axis += ndim; - } - ICHECK_LT(axis, inputs[0]->shape.size()) << "axis out of bounds"; - - Array axis_sizes; - for (auto t : inputs) { - axis_sizes.push_back(t->shape[axis]); - } - arith::Analyzer analyzer; - PrimExpr join_size = axis_sizes[0]; - for (size_t i = 1; i < axis_sizes.size(); ++i) { - join_size += axis_sizes[i]; - } - join_size = analyzer.Simplify(join_size); - Array out_shape; - for (size_t i = 0; i < inputs[0]->shape.size(); ++i) { - out_shape.push_back(i == static_cast(axis) ? join_size : inputs[0]->shape[i]); - } - - return compute( - out_shape, - [&](const Array& indices) { - auto ret = inputs[0](indices); - auto ind = indices[axis]; - for (size_t i = 0; i < inputs.size() - 1; ++i) { - ind -= axis_sizes[i]; - - Array idx; - for (size_t i = 0; i < static_cast(axis); ++i) { - idx.push_back(indices[i]); - } - idx.push_back(ind); - for (size_t i = axis + 1; i < indices.size(); ++i) { - idx.push_back(indices[i]); - } - - ret = tvm::if_then_else(ind >= 0, inputs[i + 1](idx), ret); - } - return ret; - }, - name, tag); -} - /*! * \brief Join a sequence of tensors along a new axis. * diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 6dc8e8283b4a..42e6cd3b0fa8 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -87,7 +87,7 @@ def _concat2(bb: BlockBuilder, call: Call) -> Expr: t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)] ) return bb.call_te( - topi.concatenate2, call.args[0], fields, None if call.attrs.axis is None else call.attrs.axis.value + topi.index_tensor, call.args[0], fields, None if call.attrs.axis is None else call.attrs.axis.value ) diff --git a/python/tvm/topi/broadcast.py b/python/tvm/topi/broadcast.py index 597c7f24d4c2..2b350ff817d9 100644 --- a/python/tvm/topi/broadcast.py +++ b/python/tvm/topi/broadcast.py @@ -56,8 +56,7 @@ def add(lhs, rhs): Returns Expr if both operands are Expr. Otherwise returns Tensor. """ - return lhs - # return _cpp.add(lhs, rhs) # TODO revert + return _cpp.add(lhs, rhs) def subtract(lhs, rhs): diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 8f773da47bf5..8adcb624db0d 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -402,24 +402,6 @@ def concatenate(a_tuple, axis=0): """ return cpp.concatenate(a_tuple, axis) -def concatenate2(first, a_tuple, axis=0): - """Join a sequence of arrays along an existing axis. - - Parameters - ---------- - a_tuple : tuple of tvm.te.Tensor - The arrays to concatenate - - axis : int, optional - The axis along which the arrays will be joined. Default is 0. 
- - Returns - ------- - ret : tvm.te.Tensor - """ - # return cpp.concatenate2(first, a_tuple, axis) - return topi.adv_index(first, a_tuple) - def stack(a, axis): """Repeats the whole array multiple times. @@ -1082,137 +1064,9 @@ def _apply_trilu(*indices): return te.compute(data.shape, _apply_trilu, name="trilu", tag=topi.tag.ELEMWISE) -def index_tensor(data, indices): +def index_tensor(data, indices, axis): # TODO remove axis argument """ TODO docstring - - If 'indices' is a list/tuple of length > 1, we interpret that as multiple advanced indices, - and implement with topi.adv_index (plus negative-index correction if desired). - - Otherwise, interpret 'indices' as a single advanced index, and implement with topi.take. - - Replicate data[indices] using only: - - basic indexing on data - - torch.index_select - - concatenation/stack - - broadcasting - … and no advanced indexing. - - Approach for multiple advanced indices: broadcast and loop - - Approach for single advanced index: - 1. Convert the nested Python list to a LongTensor. - 2. Remove exactly one leading dimension of size=1, if present. (Matches PyTorch's shape rule.) - 3. Flatten -> fix negative indices -> index_select -> reshape. """ return topi.adv_index(data, indices) - # print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") - # print("WE ARE IN TOPI~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") - # print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") - # print("type of data", type(data)) - # print("type of indices", type(indices)) - # print("data", data) - # print("indices", indices) - - # is_instance = isinstance(indices, (list, tuple)) - # print("isinstance(indices, (list, tuple))", is_instance) - # if is_instance: - # print("len(indices)", len(indices)) - - - # if isinstance(indices, (list, tuple)) and len(indices) > 1: - # print("IF CASE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") - - # def _broadcast_shape(shapes): - # """ - # shapes: list of tuples - # Return the broadcasted shape for these shapes - # """ - # max_ndim = max(len(s) for s in shapes) - # out_rev = [] - # # reverse each shape - # rev_shapes = [s[::-1] for s in shapes] - # for i in range(max_ndim): - # dim_size = 1 - # for rsh in rev_shapes: - # if i < len(rsh): - # s_ = rsh[i] - # # typical broadcast rule - # if s_ != 1 and dim_size != 1 and s_ != dim_size: - # raise ValueError("Incompatible shapes for broadcast") - # dim_size = max(dim_size, s_) - # out_rev.append(dim_size) - # out_rev.reverse() - # return tuple(out_rev) - - # shapes = [tuple(idx.shape) for idx in idx_list] - # broadcast_shape = _broadcast_shape(shapes) - - # # ------------------------------------------------- - # # 2) Expand (broadcast) each index to shape B - # # Then fix negative indices if you want negative support - # # ------------------------------------------------- - # expanded_idx_list = [] - # for i, idx in enumerate(idx_list): - # # broadcast to shape B - # broadcasted = topi.broadcast_to(idx, broadcast_shape) - - # # fix negative: out_idx = where(idx < 0, idx + data.shape[i], idx) - # # data.shape[i] might be a PrimExpr or int - # dim_size_i = data.shape[i] # dimension size for data's i-th dim - # # We must make sure it's broadcast-compatible: - # dim_size_t = topi.full_like(broadcasted, dim_size_i) - # zero_t = topi.full_like(broadcasted, 0) - # fixed = topi.where(topi.less(broadcasted, zero_t), - # topi.add(broadcasted, dim_size_t), - # broadcasted) - # expanded_idx_list.append(fixed) - - # # leftover dimensions => data.shape[k:] - # k = len(idx_list) - # leftover_dims = 
data.shape[k:] - # # Final output shape is broadcast_shape + leftover_dims - # final_shape = broadcast_shape + leftover_dims - - # # ------------------------------------------------- - # # 3) Build a te.compute that gathers from 'data' - # # ------------------------------------------------- - # def _compute(*args): - # # 'args' is a multi-index into final_shape - # # => the first len(broadcast_shape) are the broadcast coords - # # the remaining correspond to leftover_dims - # bdim = len(broadcast_shape) - # leftover_dim = len(leftover_dims) - # assert len(args) == bdim + leftover_dim - - # # advanced_indices for dimension i - # # i.e. i0 = expanded_idx_list[0][b0,b1,...], i1 = expanded_idx_list[1][b0,b1,...], ... - # # leftover coords => the last leftover_dim from 'args' - # # data is presumably shape = [D0, D1, ..., D(k-1), leftover...] - # # So data coordinate is [ i0, i1, ..., i(k-1), leftover0, leftover1, ...] - # data_coords = [] - # for i_ in range(k): - # data_coords.append(expanded_idx_list[i_][*args[:bdim]]) - # # Now append leftover coords - # data_coords.extend(args[bdim:]) - - # return data(*data_coords) - - # # The final te.compute - # out = te.compute( - # final_shape, - # _compute, - # name="multi_adv_index_gather", - # ) - # return out - - # else: - # # flattened = topi.reshape(indices, (-1,)) - # print("ELSE CASE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") - # idxs = indices[0] - # print("type(data)", type(data)) - # print("data", data) - # print("type(idxs)",type(idxs)) - # print("idxs", idxs) - # picked = topi.take(data, idxs, axis=0) - # return picked - diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index ba2cf272f71f..2e184724c525 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -408,7 +408,6 @@ StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { } const auto* attrs = call->attrs.as(); - int output_ndim = attrs->axis.defined() ? kUnknownNDim : 1; DataType output_dtype = data_sinfo->dtype; bool vdevice_unknown = false; @@ -441,85 +440,6 @@ StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { } else { return TensorStructInfo(output_dtype, indices_sinfo->ndim, indices_sinfo->vdevice); } - - // std::vector> shape_values; - // shape_values.reserve(tensor_sinfo.size()); - - - // // TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); - // // DataType output_dtype = data_sinfo->dtype; - - // // return TensorStructInfo(output_dtype, kUnknownNDim); - - // for (TensorStructInfo sinfo : tensor_sinfo) { // TODO need this for-loop! - - // // Update the output ndim. - // // Todo(relax-team): revisit here for better check on if the input tensor has - // // ndim 1 when the input axis is undefined. - // if (output_ndim == kUnknownNDim) { - // output_ndim = sinfo->ndim; - // } else if (sinfo->ndim != kUnknownNDim && sinfo->ndim != output_ndim) { - // ctx->ReportFatal(Diagnostic::Error(call) - // << "Concat expects all input tensors to have same ndim. However, the " - // "input contains tensors with ndim " - // << output_ndim << " and " << sinfo->ndim); - // } - - // // Update the shape values for best effort check. - // const auto* shape_expr = sinfo->shape.as(); - // if (shape_expr != nullptr) { - // shape_values.push_back(shape_expr->values); - // continue; - // } - // shape_unknown = true; - - // if (!sinfo->shape.defined()) { - // continue; - // } - // // Keep the shape value for equality check. 
- // ShapeStructInfo shape_sinfo = Downcast(sinfo->shape.value()->struct_info_); - // if (shape_sinfo->values.defined()) { - // shape_values.push_back(shape_sinfo->values.value()); - // } - // } - - // if (is_void_dtype) { - // output_dtype = DataType::Void(); - // } - - // if (output_ndim == kUnknownNDim) { - // return tensor_sinfo.size() == 1 ? tensor_sinfo[0] - // : TensorStructInfo(output_dtype, output_ndim, vdev); - // } - - // int axis = - // attrs->axis.defined() ? NormalizeAxis(call, ctx, output_ndim, attrs->axis.value()->value) : 0; - // // If there is only one input tensor, no action is needed. - // if (tensor_sinfo.size() == 1) { - // return tensor_sinfo[0]; - // } - // if (shape_values.empty()) { - // if (!vdevice_unknown) { - // return TensorStructInfo(output_dtype, output_ndim, vdev); - // } - // return TensorStructInfo(output_dtype, output_ndim); - // } - - // // TODO why do I need to output_shape here?? index_tensor did not require it! - // // As long as the there is known shape value, we will do the best effort check to ensure safety. - // Optional> output_shape = CheckConcatOutputShape2(call, ctx, shape_values, axis); - - // if (shape_unknown || !output_shape.defined()) { - // if (!vdevice_unknown) { - // return TensorStructInfo(output_dtype, output_ndim, vdev); - // } - // return TensorStructInfo(output_dtype, output_ndim); - // } else { - // if (!vdevice_unknown) { - // return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, vdev); - // } - // return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype); - // } } TVM_REGISTER_OP("relax.concat2") diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 426d8740ef78..7ef63a9b3f56 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -70,10 +70,6 @@ TVM_REGISTER_GLOBAL("topi.concatenate").set_body([](TVMArgs args, TVMRetValue* r *rv = concatenate(args[0], args[1]); }); -TVM_REGISTER_GLOBAL("topi.concatenate2").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = concatenate2(args[0], args[1], args[2]); -}); - TVM_REGISTER_GLOBAL("topi.stack").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = stack(args[0], args[1]); }); From 6cf94b68b360e971e214c30658c092c0bcc6ff52 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 03:50:50 -0400 Subject: [PATCH 083/105] passes every single test --- .../torch/base_fx_graph_translator.py | 36 --- .../torch/exported_program_translator.py | 4 +- src/relax/op/tensor/manipulate.cc | 241 ++++++++---------- 3 files changed, 110 insertions(+), 171 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index f1529e04f281..8a77561fba9e 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -93,16 +93,12 @@ def _retrieve_args(self, node): from torch import fx if isinstance(node, fx.Node): - print("isinstance(node, fx.Node)") return self.env[node] elif isinstance(node, tuple): - print("isinstance(node, tuple) of length", len(node)) return tuple(self._retrieve_args(x) for x in node) elif isinstance(node, list): - print("isinstance(node, list) of length", len(node)) return [self._retrieve_args(x) for x in node] elif isinstance(node, dict): - print("isinstance(node, dict)") return {self._retrieve_args(k): self._retrieve_args(v) for k, v in node.items()} else: return node @@ -987,11 +983,6 @@ def _cat(self, node: fx.Node) -> relax.Var: axis = args[1] if 
len(node.args) > 1 else node.kwargs.get("dim", 0) return self.block_builder.emit(relax.op.concat(args[0], axis=axis)) - def _cat2(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) - return self.block_builder.emit(relax.op.concat2(args[0][0], args[0], axis=axis)) - def _chunk(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] chunks = node.args[1] @@ -1080,37 +1071,10 @@ def _gather(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.gather_elements(x, index, axis=dim)) def _index_tensor(self, node: fx.Node) -> relax.Var: - # TODO should I be using _binary_op() ? - -# ? x = self.env[node.args[0]] - # indices = node.args[1] args = self.retrieve_args(node) - print("node: fx.Node passed to index_tensor:") - print("len of args", len(args)) - print("type of args[0]", type(args[0])) - print("args[0]", args[0]) - print("type of args[1]", type(args[1])) # Is a list no matter what!!! Like even if we pass a torch.tensor - print("args[1]", args[1]) - - # indices = args[1] # TODO do something like this! - # indices = [2,3] indices = args[1] - print("type of indices", type(indices)) - # print("indices:") - # args_indices = self.retrieve_args(indices) - # print("len of args_indices", len(args_indices)) - # print("type of args_indices[0]", type(args_indices[0])) - # print("args_indices[0]", args_indices[0]) - # print("type of args_indices[1]", type(args_indices[1])) # Is a list no matter what!!! Like even if we pass a torch.tensor - # print("args_indices[1]", args_indices[1]) - - - # index = self.env[node.args[1]] # TODO - # return self.block_builder.emit(relax.op.index_tensor(args[0], indices)) # TODO revert! removed to test collapse sum like - # return self.block_builder.emit(relax.op.index_tensor(args[0], indices)) return self.block_builder.emit(relax.op.concat2(args[0], indices)) - def _permute(self, node: fx.Node) -> relax.Var: import torch # type: ignore diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index c3d17e57b7ad..a5a76c77e921 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -390,10 +390,10 @@ def create_convert_map( "where.self": self._where, # tensor manipulation "argsort.default": self._argsort, - "cat.default": self._cat2, + "cat.default": self._cat, "chunk.default": self._chunk, "clamp.Tensor": self._clamp, - "concat.default": self._cat2, + "concat.default": self._cat, "copy_.default": self._copy_, "cumsum.default": self._cumsum, "cumprod.default": self._cumprod, diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 2e184724c525..5c669f46eb02 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -330,126 +330,6 @@ TVM_REGISTER_OP("relax.concat") .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", Bool(true)); -/* relax.concat2 */ - -Expr concat2(Expr first, Expr tensors, Optional axis) { - ObjectPtr attrs = make_object(); - attrs->axis = std::move(axis); - - static const Op& op = Op::Get("relax.concat2"); - return Call(op, {std::move(first), std::move(tensors)}, Attrs(attrs), {}); -} - -TVM_REGISTER_GLOBAL("relax.op.concat2").set_body_typed(concat2); - -Optional> CheckConcatOutputShape2(const Call& call, const BlockBuilder& ctx, - const std::vector>& shape_values, - int 
axis) { - bool shape_unknown = false; - arith::Analyzer* analyzer = ctx->GetAnalyzer(); - PrimExpr concat_sum = [&]() { - // For the specified axis, we compute the sum of shape value over each tensor. - - // Special case, if all concatenated values have the same shape - StructuralEqual structural_equal; - PrimExpr first_concat_dim = shape_values[1][axis]; - bool all_same = std::all_of(shape_values.begin(), shape_values.end(), [&](const auto& a) { - return structural_equal(a[axis], first_concat_dim); - }); - if (all_same) { - return first_concat_dim * IntImm(DataType::Int(64), shape_values.size()); - } - - // General case, add up the dimensions along the specified axis. - PrimExpr concat_sum = IntImm(DataType::Int(64), 0); - for (Array shape_value : shape_values) { - concat_sum += shape_value[axis]; - } - return concat_sum; - }(); - - // For other axes, we check the equality of all tensors' shape values, to ensure safety. - for (int d = 0; d < static_cast(shape_values[0].size()); ++d) { - if (d == axis) { - continue; - } - for (int i = 1; i < static_cast(shape_values.size()); ++i) { - if (analyzer->CanProve(shape_values[i][d] != shape_values[0][d])) { - ctx->ReportFatal(Diagnostic::Error(call) - << "Concat expects the input tensors to have the same shape on every " - "dimension except the one indicated by the input axis. However, the " - "input contains tensors whose shapes on dimension " - << d << " is " << shape_values[0][d] << " and " << shape_values[i][d]); - } else if (!analyzer->CanProveEqual(shape_values[i][d], shape_values[0][d])) { - shape_unknown = true; - } - } - } - - if (shape_unknown) { - return NullOpt; - } - Array output_shape = shape_values[0]; - output_shape.Set(axis, concat_sum); - return output_shape; -} - -StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { - if (call->args.size() != 2) { - ctx->ReportFatal(Diagnostic::Error(call) << "Concat op should have 1 argument"); - } - TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); - - Array tensor_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[1]); - if (tensor_sinfo.empty()) { - ctx->ReportFatal(Diagnostic::Error(call) - << "Concat op expects at least one tensor in the input Tuple. 
However, the " - "given input Tuple is empty."); - } - - const auto* attrs = call->attrs.as(); - DataType output_dtype = data_sinfo->dtype; - - bool vdevice_unknown = false; - Optional vdev = NullOpt; - if (data_sinfo->vdevice.defined()) { - vdev = data_sinfo->vdevice.value(); - vdevice_unknown = true; - } - - bool shape_unknown = false; - bool is_void_dtype = false; - if (data_sinfo->dtype.is_void()) { - is_void_dtype = true; - } - - TensorStructInfo indices_sinfo = data_sinfo; - - Optional> data_shape_value; - if (data_sinfo->shape.defined()) { - data_shape_value = GetStructInfoAs(data_sinfo->shape.value())->values; - } - Optional> indices_shape_value; - if (indices_sinfo->shape.defined()) { - indices_shape_value = - GetStructInfoAs(indices_sinfo->shape.value())->values; - } - - if (indices_sinfo->shape.defined()) { - return TensorStructInfo(indices_sinfo->shape.value(), output_dtype, indices_sinfo->vdevice); - } else { - return TensorStructInfo(output_dtype, indices_sinfo->ndim, indices_sinfo->vdevice); - } -} - -TVM_REGISTER_OP("relax.concat2") - .set_attrs_type() - .set_num_inputs(2) - .add_argument("first", "Tensor", "The first tensor") - .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.") - .set_attr("FInferStructInfo", InferStructInfoConcat2) // TODO necessary - .set_attr("FPurity", Bool(true)); - /* relax.expand_dims */ TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs); @@ -596,30 +476,125 @@ TVM_REGISTER_OP("relax.flatten") .set_attr("FPurity", Bool(true)); /* relax.index_tensor */ -Expr index_tensor(Expr data, Expr indices) { - // TODO do we need code below? - // ObjectPtr attrs = make_object(); - // attrs->indices = std::move(indices); - static const Op& op = Op::Get("relax.index_tensor"); - return Call(op, {std::move(data), std::move(indices)}, Attrs(), {}); +/* relax.concat2 */ + +Expr concat2(Expr first, Expr tensors, Optional axis) { + ObjectPtr attrs = make_object(); + attrs->axis = std::move(axis); + + static const Op& op = Op::Get("relax.concat2"); + return Call(op, {std::move(first), std::move(tensors)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.index_tensor").set_body_typed(index_tensor); +TVM_REGISTER_GLOBAL("relax.op.concat2").set_body_typed(concat2); + +Optional> CheckConcatOutputShape2(const Call& call, const BlockBuilder& ctx, + const std::vector>& shape_values, + int axis) { + bool shape_unknown = false; + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + PrimExpr concat_sum = [&]() { + // For the specified axis, we compute the sum of shape value over each tensor. + + // Special case, if all concatenated values have the same shape + StructuralEqual structural_equal; + PrimExpr first_concat_dim = shape_values[1][axis]; + bool all_same = std::all_of(shape_values.begin(), shape_values.end(), [&](const auto& a) { + return structural_equal(a[axis], first_concat_dim); + }); + if (all_same) { + return first_concat_dim * IntImm(DataType::Int(64), shape_values.size()); + } + + // General case, add up the dimensions along the specified axis. + PrimExpr concat_sum = IntImm(DataType::Int(64), 0); + for (Array shape_value : shape_values) { + concat_sum += shape_value[axis]; + } + return concat_sum; + }(); + + // For other axes, we check the equality of all tensors' shape values, to ensure safety. 
+ for (int d = 0; d < static_cast(shape_values[0].size()); ++d) { + if (d == axis) { + continue; + } + for (int i = 1; i < static_cast(shape_values.size()); ++i) { + if (analyzer->CanProve(shape_values[i][d] != shape_values[0][d])) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Concat expects the input tensors to have the same shape on every " + "dimension except the one indicated by the input axis. However, the " + "input contains tensors whose shapes on dimension " + << d << " is " << shape_values[0][d] << " and " << shape_values[i][d]); + } else if (!analyzer->CanProveEqual(shape_values[i][d], shape_values[0][d])) { + shape_unknown = true; + } + } + } + + if (shape_unknown) { + return NullOpt; + } + Array output_shape = shape_values[0]; + output_shape.Set(axis, concat_sum); + return output_shape; +} -StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) { +StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 2) { + ctx->ReportFatal(Diagnostic::Error(call) << "Concat op should have 1 argument"); + } TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); + Array tensor_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[1]); + if (tensor_sinfo.empty()) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Concat op expects at least one tensor in the input Tuple. However, the " + "given input Tuple is empty."); + } + + const auto* attrs = call->attrs.as(); DataType output_dtype = data_sinfo->dtype; - return TensorStructInfo(output_dtype, kUnknownNDim); + bool vdevice_unknown = false; + Optional vdev = NullOpt; + if (data_sinfo->vdevice.defined()) { + vdev = data_sinfo->vdevice.value(); + vdevice_unknown = true; + } + + bool shape_unknown = false; + bool is_void_dtype = false; + if (data_sinfo->dtype.is_void()) { + is_void_dtype = true; + } + + TensorStructInfo indices_sinfo = data_sinfo; + + Optional> data_shape_value; + if (data_sinfo->shape.defined()) { + data_shape_value = GetStructInfoAs(data_sinfo->shape.value())->values; + } + Optional> indices_shape_value; + if (indices_sinfo->shape.defined()) { + indices_shape_value = + GetStructInfoAs(indices_sinfo->shape.value())->values; + } + + if (indices_sinfo->shape.defined()) { + return TensorStructInfo(indices_sinfo->shape.value(), output_dtype, indices_sinfo->vdevice); + } else { + return TensorStructInfo(output_dtype, indices_sinfo->ndim, indices_sinfo->vdevice); + } } -TVM_REGISTER_OP("relax.index_tensor") +TVM_REGISTER_OP("relax.concat2") + .set_attrs_type() // TODO remove that .set_num_inputs(2) - .add_argument("data", "Tensor", "The input tensor.") - .add_argument("indices", "Tuple of Tensor", "The indices tensor.") - .set_attr("FInferStructInfo", InferStructInfoIndexTensor) + .add_argument("first", "Tensor", "The first tensor") + .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.") + .set_attr("FInferStructInfo", InferStructInfoConcat2) // TODO necessary .set_attr("FPurity", Bool(true)); /* relax.layout_transform */ From e5e9688e44139e0d99b8d64136a495c5754f02da Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 04:04:00 -0400 Subject: [PATCH 084/105] won't compile anymore --- .../torch/exported_program_translator.py | 1 - python/tvm/relax/op/__init__.py | 1 - python/tvm/relax/op/manipulate.py | 54 +++++-------- python/tvm/relax/op/op_attrs.py | 6 -- .../transform/legalize_ops/manipulate.py | 28 +------ python/tvm/topi/transform.py | 11 --- src/relax/op/tensor/manipulate.cc | 80 
+++---------------- 7 files changed, 34 insertions(+), 147 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index a5a76c77e921..7fafdae7d4fe 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -544,7 +544,6 @@ def from_exported_program( assert ( func_name in self.convert_map ), f"Unsupported function type {func_name}" - print("found function", func_name, "!!!!!!!!!!!!!!!!!!!!!!!!!!!!") self.env[node] = self.convert_map[func_name](node) else: raise ValueError(f"Unsupported op {node.op}") diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index f9a92a9549da..f81c5448a024 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -89,7 +89,6 @@ collapse_sum_like, collapse_sum_to, concat, - concat2, expand_dims, flatten, flip, diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 9a3d5e8e7760..e70ddecd563f 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -70,31 +70,6 @@ def concat(tensors: Union[Expr, List[Expr]], axis: Optional[int] = 0) -> Expr: tensors = RxTuple(tensors) return _ffi_api.concat(tensors, axis) # type: ignore - -def concat2(first:Expr, tensors: Union[Expr, List[Expr]], axis: Optional[int] = 0) -> Expr: - """Concatenate the input tensors along the given axis. - - Parameters - ---------- - tensors : Union[relax.Expr, List[relax.Expr]] - An Expr in Tuple type, containing the tensors to be concatenated, - or a list of Tensors. - - axis : Optional[int] - The axis along which the tensors are concatenated. - If `axis` is `None`, the input tensor is required to be flattened before concatenation. - - Returns - ------- - result: relax.Expr - The concatenated tensor. - """ - if isinstance(tensors, (list, tuple)): - tensors = RxTuple(tensors) - # return _ffi_api.concat2(tensors, axis) # TODO this works for some reason! - return _ffi_api.concat2(first, tensors, axis) # type: ignore - - def expand_dims(x: Expr, axis: Union[int, List[int]]) -> Expr: """Insert new axes at the positions given by `axis`. @@ -531,18 +506,29 @@ def gather_nd(data: Expr, indices: Expr, batch_dims: int = 0) -> Expr: return _ffi_api.gather_nd(data, indices, batch_dims) # type: ignore -def index_tensor(data: Expr, indices: Union[Expr, List[Expr]]) -> Expr: - """ - TODO docstring +# TODO change names of args and remove axis arg +def index_tensor(data:Expr, indices: Union[Expr, List[Expr]], axis: Optional[int] = 0) -> Expr: + """Concatenate the input tensors along the given axis. + + Parameters + ---------- + tensors : Union[relax.Expr, List[relax.Expr]] + An Expr in Tuple type, containing the tensors to be concatenated, + or a list of Tensors. + + axis : Optional[int] + The axis along which the tensors are concatenated. + If `axis` is `None`, the input tensor is required to be flattened before concatenation. + + Returns + ------- + result: relax.Expr + The concatenated tensor. """ - # TODO loosen those assertions! Need to handler lists of lists of lists etc. - # assert isinstance(indices, list), f"indices should be a list, but is a {type(indices)}. 
Data is a {type(data)}" - # assert all(isinstance(i, int) for i in indices), "indices should be a list of integers, but got {}".format( - # [type(i) for i in indices] - # ) if isinstance(indices, (list, tuple)): indices = RxTuple(indices) - return _ffi_api.index_tensor(data, indices) # type: ignore + # return _ffi_api.concat2(tensors, axis) # TODO this works for some reason! + return _ffi_api.index_tensor(data, indices, axis) # type: ignore def scatter_elements( data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str = "update" diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 58fe6b2f5f6d..4658950f511a 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -182,9 +182,3 @@ class EinsumAttrs(Attrs): @tvm._ffi.register_object("relax.attrs.FlipAttrs") class FlipAttrs(Attrs): """Attributes for flip operator""" - - -# TODO is this needed? It looks like not all ops are here -# @tvm._ffi.register_object("relax.attrs.IndexTensorAttrs") -# class IndexTensorAttrs(Attrs): -# """Attributes used in index_tensor operator""" diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 42e6cd3b0fa8..6aa150183ff5 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -72,24 +72,6 @@ def _concat(bb: BlockBuilder, call: Call) -> Expr: ) -@register_legalize("relax.concat2") -def _concat2(bb: BlockBuilder, call: Call) -> Expr: - t = call.args[1] - n_field = len(t.struct_info.fields) - while isinstance(t, Var): - binding = bb.lookup_binding(t) - if not isinstance(binding, (Tuple, Var)): - break - t = binding - - assert isinstance(t, (Tuple, Var)) - fields = ( - t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)] - ) - return bb.call_te( - topi.index_tensor, call.args[0], fields, None if call.attrs.axis is None else call.attrs.axis.value - ) - @register_legalize("relax.expand_dims") def _expand_dims(bb: BlockBuilder, call: Call) -> Expr: @@ -181,11 +163,6 @@ def te_gather_nd(data, indices, batch_dims): return bb.call_te(te_gather_nd, call.args[0], call.args[1], int(call.attrs.batch_dims)) -# @register_legalize("relax.index_tensor") -# def _index_tensor(bb: BlockBuilder, call: Call) -> Expr: -# return bb.call_te(topi.index_tensor, call.args[0], call.args[1][0]) - - @register_legalize("relax.index_tensor") def _index_tensor(bb: BlockBuilder, call: Call) -> Expr: @@ -201,11 +178,10 @@ def _index_tensor(bb: BlockBuilder, call: Call) -> Expr: fields = ( t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)] ) - return bb.call_te( - topi.index_tensor, call.args[0], fields + return bb.call_te( # TODO remove axis + topi.index_tensor, call.args[0], fields, None if call.attrs.axis is None else call.attrs.axis.value ) - @register_legalize("relax.scatter_elements") def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr: return bb.call_te( diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 8adcb624db0d..a37210b552a9 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -956,17 +956,6 @@ def adv_index(data, indices): result : tvm.te.Tensor Output tensor """ - - """ - TODO - this seems to be wrong - Does not achieve correctness with this: - - x np.random.rand(5,5,5,5).astype("float32") - return x[[[0,1],[0,1]]] - - """ - return cpp.adv_index(data, indices) diff --git 
a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 5c669f46eb02..ce64c4aba779 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -477,84 +477,29 @@ TVM_REGISTER_OP("relax.flatten") /* relax.index_tensor */ -/* relax.concat2 */ - -Expr concat2(Expr first, Expr tensors, Optional axis) { - ObjectPtr attrs = make_object(); +Expr index_tensor(Expr first, Expr tensors, Optional axis) { + ObjectPtr attrs = make_object(); // TODO remove this attrs->axis = std::move(axis); - static const Op& op = Op::Get("relax.concat2"); + static const Op& op = Op::Get("relax.index_tensor"); return Call(op, {std::move(first), std::move(tensors)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.concat2").set_body_typed(concat2); - -Optional> CheckConcatOutputShape2(const Call& call, const BlockBuilder& ctx, - const std::vector>& shape_values, - int axis) { - bool shape_unknown = false; - arith::Analyzer* analyzer = ctx->GetAnalyzer(); - PrimExpr concat_sum = [&]() { - // For the specified axis, we compute the sum of shape value over each tensor. - - // Special case, if all concatenated values have the same shape - StructuralEqual structural_equal; - PrimExpr first_concat_dim = shape_values[1][axis]; - bool all_same = std::all_of(shape_values.begin(), shape_values.end(), [&](const auto& a) { - return structural_equal(a[axis], first_concat_dim); - }); - if (all_same) { - return first_concat_dim * IntImm(DataType::Int(64), shape_values.size()); - } - - // General case, add up the dimensions along the specified axis. - PrimExpr concat_sum = IntImm(DataType::Int(64), 0); - for (Array shape_value : shape_values) { - concat_sum += shape_value[axis]; - } - return concat_sum; - }(); - - // For other axes, we check the equality of all tensors' shape values, to ensure safety. - for (int d = 0; d < static_cast(shape_values[0].size()); ++d) { - if (d == axis) { - continue; - } - for (int i = 1; i < static_cast(shape_values.size()); ++i) { - if (analyzer->CanProve(shape_values[i][d] != shape_values[0][d])) { - ctx->ReportFatal(Diagnostic::Error(call) - << "Concat expects the input tensors to have the same shape on every " - "dimension except the one indicated by the input axis. However, the " - "input contains tensors whose shapes on dimension " - << d << " is " << shape_values[0][d] << " and " << shape_values[i][d]); - } else if (!analyzer->CanProveEqual(shape_values[i][d], shape_values[0][d])) { - shape_unknown = true; - } - } - } - - if (shape_unknown) { - return NullOpt; - } - Array output_shape = shape_values[0]; - output_shape.Set(axis, concat_sum); - return output_shape; -} +TVM_REGISTER_GLOBAL("relax.op.index_tensor").set_body_typed(index_tensor); -StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { +StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { - ctx->ReportFatal(Diagnostic::Error(call) << "Concat op should have 1 argument"); + ctx->ReportFatal(Diagnostic::Error(call) << "Index.Tensor op should have 2 arguments"); } TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); Array tensor_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[1]); if (tensor_sinfo.empty()) { ctx->ReportFatal(Diagnostic::Error(call) - << "Concat op expects at least one tensor in the input Tuple. However, the " - "given input Tuple is empty."); + << "Index.Tensor expects at least one tensor in the input Tuple. 
However, the " + "given input Tuple is empty."); // TODO is this always true? } - const auto* attrs = call->attrs.as(); DataType output_dtype = data_sinfo->dtype; bool vdevice_unknown = false; @@ -564,7 +509,6 @@ StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { vdevice_unknown = true; } - bool shape_unknown = false; bool is_void_dtype = false; if (data_sinfo->dtype.is_void()) { is_void_dtype = true; @@ -589,12 +533,12 @@ StructInfo InferStructInfoConcat2(const Call& call, const BlockBuilder& ctx) { } } -TVM_REGISTER_OP("relax.concat2") +TVM_REGISTER_OP("relax.index_tensor") .set_attrs_type() // TODO remove that .set_num_inputs(2) - .add_argument("first", "Tensor", "The first tensor") - .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.") - .set_attr("FInferStructInfo", InferStructInfoConcat2) // TODO necessary + .add_argument("data", "Tensor", "The input data.") + .add_argument("indices", "List of Tensors", "The indices used to index.") + .set_attr("FInferStructInfo", InferStructInfoIndexTensor) // TODO necessary .set_attr("FPurity", Bool(true)); /* relax.layout_transform */ From e83466191a6e170c9b84a63dff544d2a38e67ef5 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 04:09:28 -0400 Subject: [PATCH 085/105] all tests pass --- .../relax/frontend/torch/base_fx_graph_translator.py | 2 +- python/tvm/relax/op/manipulate.py | 1 - src/relax/op/tensor/manipulate.cc | 12 ------------ src/relax/op/tensor/manipulate.h | 2 +- 4 files changed, 2 insertions(+), 15 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 8a77561fba9e..c047e3e8e07a 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1073,7 +1073,7 @@ def _gather(self, node: fx.Node) -> relax.Var: def _index_tensor(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) indices = args[1] - return self.block_builder.emit(relax.op.concat2(args[0], indices)) + return self.block_builder.emit(relax.op.index_tensor(args[0], indices)) def _permute(self, node: fx.Node) -> relax.Var: import torch # type: ignore diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index e70ddecd563f..ef3a9b653a47 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -527,7 +527,6 @@ def index_tensor(data:Expr, indices: Union[Expr, List[Expr]], axis: Optional[int """ if isinstance(indices, (list, tuple)): indices = RxTuple(indices) - # return _ffi_api.concat2(tensors, axis) # TODO this works for some reason! 
return _ffi_api.index_tensor(data, indices, axis) # type: ignore def scatter_elements( diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index ce64c4aba779..1aa98c4a8329 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -502,18 +502,6 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) DataType output_dtype = data_sinfo->dtype; - bool vdevice_unknown = false; - Optional vdev = NullOpt; - if (data_sinfo->vdevice.defined()) { - vdev = data_sinfo->vdevice.value(); - vdevice_unknown = true; - } - - bool is_void_dtype = false; - if (data_sinfo->dtype.is_void()) { - is_void_dtype = true; - } - TensorStructInfo indices_sinfo = data_sinfo; Optional> data_shape_value; diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 0c7d482ffacb..7aa641d6b971 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -211,7 +211,7 @@ Expr gather_nd(Expr data, Expr indices, int batch_dims = 0); * The output shape is batch_dims + indices.shape[:-1] + data.shape[batch_dims + * indices.shape[-1]:] */ -Expr index_tensor(Expr data, Expr indices); +Expr index_tensor(Expr data, Expr indices, Optional axis); // TODO remove axis /*! * \brief Scatter updates into an array according to indices. From 2f87583857e931d2041c9cc020ec87b9723dbf6d Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 04:13:51 -0400 Subject: [PATCH 086/105] minimum manipulate.cc -> all tests pass --- src/relax/op/tensor/manipulate.cc | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 1aa98c4a8329..786543a8dfb2 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -504,17 +504,8 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) TensorStructInfo indices_sinfo = data_sinfo; - Optional> data_shape_value; - if (data_sinfo->shape.defined()) { - data_shape_value = GetStructInfoAs(data_sinfo->shape.value())->values; - } - Optional> indices_shape_value; - if (indices_sinfo->shape.defined()) { - indices_shape_value = - GetStructInfoAs(indices_sinfo->shape.value())->values; - } - - if (indices_sinfo->shape.defined()) { + if (indices_sinfo->shape.defined()) { // TODO need this condition, but not sure why! Isn't that + // not reflective of the output anyway? 
return TensorStructInfo(indices_sinfo->shape.value(), output_dtype, indices_sinfo->vdevice); } else { return TensorStructInfo(output_dtype, indices_sinfo->ndim, indices_sinfo->vdevice); From c2de1e2c682c44a85f25b21d3846dca4a8cc7ebf Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 04:34:07 -0400 Subject: [PATCH 087/105] cleanup --- tests/python/relax/test_from_exported_OLD.py | 219 ------------------ .../python/relax/test_from_exported_concat.py | 86 ------- .../relax/test_from_exported_to_cuda.py | 20 ++ .../relax/test_from_exported_to_cuda_NEW.py | 67 ------ 4 files changed, 20 insertions(+), 372 deletions(-) delete mode 100644 tests/python/relax/test_from_exported_OLD.py delete mode 100644 tests/python/relax/test_from_exported_concat.py delete mode 100644 tests/python/relax/test_from_exported_to_cuda_NEW.py diff --git a/tests/python/relax/test_from_exported_OLD.py b/tests/python/relax/test_from_exported_OLD.py deleted file mode 100644 index 289281d7648a..000000000000 --- a/tests/python/relax/test_from_exported_OLD.py +++ /dev/null @@ -1,219 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import tvm -from tvm import relax -import tvm.testing -import numpy as np -import torch -from torch import nn -from torch.export import export -from tvm.relax.frontend.torch import from_exported_program -from torch.nn import Softmax, Upsample - - -def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev): - """ - This util ensures that a torch module can successfully be exported to TVM - using torch.export and that the resuling IR program gives the same result - as PyTorch when ran on CUDA. 
- """ - raw_data_for_tvm = raw_data.copy() # In case the data is modified - torch_data = torch.from_numpy(raw_data) - example_args = (torch_data,) - - with torch.no_grad(): - exported_program = export(torch_module, example_args) - mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True) - - tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch) - - relax_pipeline = relax.get_default_pipeline(tvm.target.Target.from_device(tvm.cuda())) - ex = relax.build(tvm_mod, target=target, relax_pipeline=relax_pipeline) - vm = relax.VirtualMachine(ex, dev) - - gpu_data = tvm.nd.array(raw_data_for_tvm, dev) - gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]] - gpu_out = vm["main"](gpu_data, *gpu_params) - - pytorch_out = torch_module(torch_data) - - if isinstance(pytorch_out, tuple): - for i in range(len(pytorch_out)): - actual = gpu_out[i].numpy() - desired = pytorch_out[i].detach().numpy() - np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) - else: - actual = gpu_out[0].numpy() - desired = pytorch_out.detach().numpy() - np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) - - - -@tvm.testing.parametrize_targets("cuda") -def test_full(target, dev): - class FullModel(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.full((2, 3), 3.141592) - - torch_module = FullModel().eval() - - raw_data = np.random.rand(3,3).astype("float32") - - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - - -# Test index.Tensor # TODO aggregate into one big tet - -@tvm.testing.parametrize_targets("cuda") -def test_index_tensor0(target, dev): - class IndexModel0(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x[torch.tensor([0])] - - torch_module = IndexModel0().eval() - - raw_data = np.random.rand(3,3).astype("float32") - - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - - -@tvm.testing.parametrize_targets("cuda") -def test_index_tensor1(target, dev): - class IndexModel1(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x[torch.tensor([[0]])] - - torch_module = IndexModel1().eval() - - raw_data = np.random.rand(2,3).astype("float32") - - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - - -@tvm.testing.parametrize_targets("cuda") -def test_index_tensor2(target, dev): - class IndexTensorModel2(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x[torch.tensor([0,2])] - - torch_module = IndexTensorModel2().eval() - - raw_data = np.random.rand(3,4).astype("float32") - - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - - -@tvm.testing.parametrize_targets("cuda") -def test_index_tensor3(target, dev): - class IndexTensorModel3(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x[[[[0,2],[1,3]]]] - - torch_module = IndexTensorModel3().eval() - raw_data = np.random.rand(5,5,5).astype("float32") - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - - -@tvm.testing.parametrize_targets("cuda") -def test_index_tensor4(target, dev): - class IndexTensorModel4(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x[[[1,4]]] - - torch_module = IndexTensorModel4().eval() - raw_data = 
np.random.rand(5,5,5).astype("float32") - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - - -@tvm.testing.parametrize_targets("cuda") -def test_index_tensor5(target, dev): - class IndexTensorModel5(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x[[[[1,2,4]]]] - - torch_module = IndexTensorModel5().eval() - raw_data = np.random.rand(5,5,5).astype("float32") - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - - -@tvm.testing.parametrize_targets("cuda") -def test_index_tensor6(target, dev): - class IndexTensorModel6(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x[[[0,1],[0,1]]] - - torch_module = IndexTensorModel6().eval() - raw_data = np.random.rand(5,5,5,5).astype("float32") - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - - -@tvm.testing.parametrize_targets("cuda") -def test_index_tensor7(target, dev): - class IndexTensorModel7(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x[[[0,1,2,3], [1,2,3,4], [2,3,4,0]]] # both args[0] and indices are expr.Var - - torch_module = IndexTensorModel7().eval() - raw_data = np.random.rand(5,5,5,5).astype("float32") - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - - -@tvm.testing.parametrize_targets("cuda") -def test_index_tensor8(target, dev): - class IndexTensorModel8(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x[[[[0,1],[2,3]],[[2,3],[3,4]],[[2,4],[1,2]],[[0,4],[0,3]]]] - - torch_module = IndexTensorModel8().eval() - raw_data = np.random.rand(5,5,5,5).astype("float32") - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/relax/test_from_exported_concat.py b/tests/python/relax/test_from_exported_concat.py deleted file mode 100644 index cc64334a147c..000000000000 --- a/tests/python/relax/test_from_exported_concat.py +++ /dev/null @@ -1,86 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import tvm -from tvm import relax -import tvm.testing -import numpy as np -import torch -from torch import nn -from torch.export import export -from tvm.relax.frontend.torch import from_exported_program -from torch.nn import Softmax, Upsample - - -def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev): - """ - This util ensures that a torch module can successfully be exported to TVM - using torch.export and that the resuling IR program gives the same result - as PyTorch when ran on CUDA. 
- """ - raw_data_for_tvm = raw_data.copy() # In case the data is modified - torch_data = torch.from_numpy(raw_data) - example_args = (torch_data,) - - with torch.no_grad(): - exported_program = export(torch_module, example_args) - mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True) - - tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch) - - relax_pipeline = relax.get_default_pipeline(tvm.target.Target.from_device(tvm.cuda())) - ex = relax.build(tvm_mod, target=target, relax_pipeline=relax_pipeline) - vm = relax.VirtualMachine(ex, dev) - - gpu_data = tvm.nd.array(raw_data_for_tvm, dev) - gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]] - gpu_out = vm["main"](gpu_data, *gpu_params) - - pytorch_out = torch_module(torch_data) - - if isinstance(pytorch_out, tuple): - for i in range(len(pytorch_out)): - actual = gpu_out[i].numpy() - desired = pytorch_out[i].detach().numpy() - np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) - else: - actual = gpu_out[0].numpy() - desired = pytorch_out.detach().numpy() - np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) - -@tvm.testing.parametrize_targets("cuda") -def test_index_tensor2(target, dev): - class ConcatFour(nn.Module): - def __init__(self, dim=0): - super(ConcatFour, self).__init__() - self.dim = dim - self.x2 = torch.randn(2, 3) - self.x3 = torch.randn(2, 3) - self.x4 = torch.randn(2, 3) - - def forward(self, x): - return torch.cat((x ,self.x2, self.x3, self.x4), dim=self.dim) - - torch_module = ConcatFour().eval() - - raw_data = np.random.rand(2,3).astype("float32") - - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index b76999d781f1..fd03a82510d8 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -754,5 +754,25 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_concat(target, dev): + class ConcatFour(nn.Module): + def __init__(self, dim=0): + super(ConcatFour, self).__init__() + self.dim = dim + self.x2 = torch.randn(2, 3) + self.x3 = torch.randn(2, 3) + self.x4 = torch.randn(2, 3) + + def forward(self, x): + return torch.cat((x ,self.x2, self.x3, self.x4), dim=self.dim) + + torch_module = ConcatFour().eval() + + raw_data = np.random.rand(2,3).astype("float32") + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_from_exported_to_cuda_NEW.py b/tests/python/relax/test_from_exported_to_cuda_NEW.py deleted file mode 100644 index 8744af9a82f7..000000000000 --- a/tests/python/relax/test_from_exported_to_cuda_NEW.py +++ /dev/null @@ -1,67 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import tvm -from tvm import relax -import tvm.testing -import numpy as np -import torch -from torch import nn -from torch.export import export -from tvm.relax.frontend.torch import from_exported_program -from torch.nn import Softmax, Upsample - - -def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev): - """ - This util ensures that a torch module can successfully be exported to TVM - using torch.export and that the resuling IR program gives the same result - as PyTorch when ran on CUDA. - """ - raw_data_for_tvm = raw_data.copy() # In case the data is modified - torch_data = torch.from_numpy(raw_data) - example_args = (torch_data,) - - with torch.no_grad(): - exported_program = export(torch_module, example_args) - mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True) - - tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch) - - relax_pipeline = relax.get_default_pipeline(tvm.target.Target.from_device(tvm.cuda())) - ex = relax.build(tvm_mod, target=target, relax_pipeline=relax_pipeline) - vm = relax.VirtualMachine(ex, dev) - - gpu_data = tvm.nd.array(raw_data_for_tvm, dev) - gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]] - gpu_out = vm["main"](gpu_data, *gpu_params) - - pytorch_out = torch_module(torch_data) - - if isinstance(pytorch_out, tuple): - for i in range(len(pytorch_out)): - actual = gpu_out[i].numpy() - desired = pytorch_out[i].detach().numpy() - np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) - else: - actual = gpu_out[0].numpy() - desired = pytorch_out.detach().numpy() - np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) - - -if __name__ == "__main__": - tvm.testing.main() From 7dd9e8fe6052030c185718e974d66fbf01844bfc Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 13:30:29 -0400 Subject: [PATCH 088/105] removed axis, all tests pass, doc TODOs remain --- python/tvm/relax/op/manipulate.py | 4 ++-- python/tvm/relax/transform/legalize_ops/manipulate.py | 4 ++-- python/tvm/topi/transform.py | 2 +- src/relax/op/tensor/manipulate.cc | 8 ++------ src/relax/op/tensor/manipulate.h | 2 +- 5 files changed, 8 insertions(+), 12 deletions(-) diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index ef3a9b653a47..254e2c8b624d 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -507,7 +507,7 @@ def gather_nd(data: Expr, indices: Expr, batch_dims: int = 0) -> Expr: # TODO change names of args and remove axis arg -def index_tensor(data:Expr, indices: Union[Expr, List[Expr]], axis: Optional[int] = 0) -> Expr: +def index_tensor(data:Expr, indices: Union[Expr, List[Expr]]) -> Expr: """Concatenate the input tensors along the given axis. 
Parameters @@ -527,7 +527,7 @@ def index_tensor(data:Expr, indices: Union[Expr, List[Expr]], axis: Optional[int """ if isinstance(indices, (list, tuple)): indices = RxTuple(indices) - return _ffi_api.index_tensor(data, indices, axis) # type: ignore + return _ffi_api.index_tensor(data, indices) # type: ignore def scatter_elements( data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str = "update" diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 6aa150183ff5..b1bbaa4db28d 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -178,8 +178,8 @@ def _index_tensor(bb: BlockBuilder, call: Call) -> Expr: fields = ( t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)] ) - return bb.call_te( # TODO remove axis - topi.index_tensor, call.args[0], fields, None if call.attrs.axis is None else call.attrs.axis.value + return bb.call_te( + topi.index_tensor, call.args[0], fields ) @register_legalize("relax.scatter_elements") diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index a37210b552a9..f5a0809493d9 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1053,7 +1053,7 @@ def _apply_trilu(*indices): return te.compute(data.shape, _apply_trilu, name="trilu", tag=topi.tag.ELEMWISE) -def index_tensor(data, indices, axis): # TODO remove axis argument +def index_tensor(data, indices): # TODO remove axis argument """ TODO docstring """ diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 786543a8dfb2..e17d3dcd4574 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -477,12 +477,9 @@ TVM_REGISTER_OP("relax.flatten") /* relax.index_tensor */ -Expr index_tensor(Expr first, Expr tensors, Optional axis) { - ObjectPtr attrs = make_object(); // TODO remove this - attrs->axis = std::move(axis); - +Expr index_tensor(Expr first, Expr tensors) { static const Op& op = Op::Get("relax.index_tensor"); - return Call(op, {std::move(first), std::move(tensors)}, Attrs(attrs), {}); + return Call(op, {std::move(first), std::move(tensors)}, Attrs(), {}); } TVM_REGISTER_GLOBAL("relax.op.index_tensor").set_body_typed(index_tensor); @@ -513,7 +510,6 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) } TVM_REGISTER_OP("relax.index_tensor") - .set_attrs_type() // TODO remove that .set_num_inputs(2) .add_argument("data", "Tensor", "The input data.") .add_argument("indices", "List of Tensors", "The indices used to index.") diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 7aa641d6b971..2fd2ed225d69 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -211,7 +211,7 @@ Expr gather_nd(Expr data, Expr indices, int batch_dims = 0); * The output shape is batch_dims + indices.shape[:-1] + data.shape[batch_dims + * indices.shape[-1]:] */ -Expr index_tensor(Expr data, Expr indices, Optional axis); // TODO remove axis +Expr index_tensor(Expr data, Expr indices); /*! * \brief Scatter updates into an array according to indices. 
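For reference, the behaviour index_tensor is meant to implement (the docstring above still carries the concat wording, and the commit subject notes that doc TODOs remain) is NumPy/PyTorch-style advanced indexing, which topi.adv_index provides once the op is legalized. A minimal NumPy sketch of that rule, illustrative only and not part of any patch (the names data/rows/cols are made up for the example):

    import numpy as np

    data = np.random.rand(5, 5, 5).astype("float32")
    rows = np.array([[0, 1], [2, 3]])
    cols = np.array([[1, 2], [3, 4]])

    # The index tensors are broadcast to a common shape B and gather from the
    # leading dimensions of data, so the output shape is
    # B + data.shape[len(indices):] == (2, 2) + (5,) == (2, 2, 5).
    out = data[rows, cols]
    assert out.shape == (2, 2, 5)

In Relax terms this corresponds to relax.op.index_tensor(x, [rows, cols]), lowered through topi.adv_index as in the legalization change earlier in this series.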
From 2f2505f4bfd296f1d5b82acedad9be376e0925e3 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 13:42:06 -0400 Subject: [PATCH 089/105] resolve conflict --- python/tvm/relax/frontend/torch/exported_program_translator.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 3f6ba3bb951f..aff99aa4e599 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -410,11 +410,8 @@ def create_convert_map( "flatten.using_ints": self._flatten, "flip.default": self._flip, "gather.default": self._gather, -<<<<<<< HEAD "index.Tensor": self._index_tensor, -======= "narrow.default": self._narrow, ->>>>>>> upstream/main "permute.default": self._permute, "repeat.default": self._repeat, "select.int": self._select, From 41bf141341e173bcc1780fe43d946cd3ac4e47a8 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 14 Apr 2025 13:47:46 -0400 Subject: [PATCH 090/105] linting --- python/tvm/relax/op/manipulate.py | 5 +- .../transform/legalize_ops/manipulate.py | 8 +-- python/tvm/script/ir_builder/relax/ir.py | 4 +- python/tvm/topi/transform.py | 7 +- src/relax/op/tensor/manipulate.h | 2 +- .../relax/test_from_exported_to_cuda.py | 69 ++++++++++--------- 6 files changed, 49 insertions(+), 46 deletions(-) diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 254e2c8b624d..58a2d686f2bb 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -70,6 +70,7 @@ def concat(tensors: Union[Expr, List[Expr]], axis: Optional[int] = 0) -> Expr: tensors = RxTuple(tensors) return _ffi_api.concat(tensors, axis) # type: ignore + def expand_dims(x: Expr, axis: Union[int, List[int]]) -> Expr: """Insert new axes at the positions given by `axis`. @@ -298,6 +299,7 @@ def collapse_sum_like(data: Expr, collapse_target: Expr) -> Expr: """ return _ffi_api.collapse_sum_like(data, collapse_target) # type: ignore + def collapse_sum_to(data: Expr, shape: Union[Tuple[PrimExprLike], Expr]) -> Expr: """Return a summation of data to the given shape. @@ -507,7 +509,7 @@ def gather_nd(data: Expr, indices: Expr, batch_dims: int = 0) -> Expr: # TODO change names of args and remove axis arg -def index_tensor(data:Expr, indices: Union[Expr, List[Expr]]) -> Expr: +def index_tensor(data: Expr, indices: Union[Expr, List[Expr]]) -> Expr: """Concatenate the input tensors along the given axis. 
Parameters @@ -529,6 +531,7 @@ def index_tensor(data:Expr, indices: Union[Expr, List[Expr]]) -> Expr: indices = RxTuple(indices) return _ffi_api.index_tensor(data, indices) # type: ignore + def scatter_elements( data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str = "update" ): diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index b1bbaa4db28d..68ed9d423a92 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -45,7 +45,7 @@ def reshape_call_te(bb: BlockBuilder, call: Call): register_legalize("relax.broadcast_to", _reshape(topi.broadcast_to, "broadcast_to")) register_legalize("relax.reshape", _reshape(topi.reshape, "reshape")) -register_legalize( +register_legalize( "relax.collapse_sum_like", _reshape(topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True), ) @@ -72,7 +72,6 @@ def _concat(bb: BlockBuilder, call: Call) -> Expr: ) - @register_legalize("relax.expand_dims") def _expand_dims(bb: BlockBuilder, call: Call) -> Expr: def te_expand_dims(data, axis): @@ -178,9 +177,8 @@ def _index_tensor(bb: BlockBuilder, call: Call) -> Expr: fields = ( t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)] ) - return bb.call_te( - topi.index_tensor, call.args[0], fields - ) + return bb.call_te(topi.index_tensor, call.args[0], fields) + @register_legalize("relax.scatter_elements") def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr: diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index e0ead3098d5a..e1687a6d5343 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -101,7 +101,7 @@ greater_equal, hint_on_device, image, - index_tensor, # TODO do something with this or remove? + index_tensor, # TODO do something with this or remove? invoke_closure, invoke_pure_closure, isfinite, @@ -784,7 +784,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "hexagon", "hint_on_device", "image", - "index_tensor", # TODO keep or remove? + "index_tensor", # TODO keep or remove? "invoke_closure", "invoke_pure_closure", "isfinite", diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index f5a0809493d9..c2b826f43a8e 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1053,9 +1053,8 @@ def _apply_trilu(*indices): return te.compute(data.shape, _apply_trilu, name="trilu", tag=topi.tag.ELEMWISE) -def index_tensor(data, indices): # TODO remove axis argument - """ TODO docstring - """ - return topi.adv_index(data, indices) +def index_tensor(data, indices): # TODO remove axis argument + """TODO docstring""" + return topi.adv_index(data, indices) diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 2fd2ed225d69..0c7d482ffacb 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -211,7 +211,7 @@ Expr gather_nd(Expr data, Expr indices, int batch_dims = 0); * The output shape is batch_dims + indices.shape[:-1] + data.shape[batch_dims + * indices.shape[-1]:] */ -Expr index_tensor(Expr data, Expr indices); +Expr index_tensor(Expr data, Expr indices); /*! * \brief Scatter updates into an array according to indices. 
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 2f9489327a31..da4197d9ce39 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -63,7 +63,6 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) - @tvm.testing.parametrize_targets("cuda") def test_full(target, dev): class FullModel(nn.Module): @@ -72,16 +71,17 @@ def __init__(self): def forward(self, x): return torch.full((2, 3), 3.141592) - + torch_module = FullModel().eval() - raw_data = np.random.rand(3,3).astype("float32") + raw_data = np.random.rand(3, 3).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) # Test index.Tensor # TODO aggregate into one big tet + @tvm.testing.parametrize_targets("cuda") def test_index_tensor0(target, dev): class IndexModel0(nn.Module): @@ -90,10 +90,10 @@ def __init__(self): def forward(self, x): return x[torch.tensor([0])] - + torch_module = IndexModel0().eval() - raw_data = np.random.rand(3,3).astype("float32") + raw_data = np.random.rand(3, 3).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) @@ -106,10 +106,10 @@ def __init__(self): def forward(self, x): return x[torch.tensor([[0]])] - + torch_module = IndexModel1().eval() - raw_data = np.random.rand(2,3).astype("float32") + raw_data = np.random.rand(2, 3).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) @@ -121,11 +121,11 @@ def __init__(self): super().__init__() def forward(self, x): - return x[torch.tensor([0,2])] - + return x[torch.tensor([0, 2])] + torch_module = IndexTensorModel2().eval() - raw_data = np.random.rand(3,4).astype("float32") + raw_data = np.random.rand(3, 4).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) @@ -137,10 +137,10 @@ def __init__(self): super().__init__() def forward(self, x): - return x[[[[0,2],[1,3]]]] - + return x[[[[0, 2], [1, 3]]]] + torch_module = IndexTensorModel3().eval() - raw_data = np.random.rand(5,5,5).astype("float32") + raw_data = np.random.rand(5, 5, 5).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) @@ -151,10 +151,10 @@ def __init__(self): super().__init__() def forward(self, x): - return x[[[1,4]]] - + return x[[[1, 4]]] + torch_module = IndexTensorModel4().eval() - raw_data = np.random.rand(5,5,5).astype("float32") + raw_data = np.random.rand(5, 5, 5).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) @@ -165,10 +165,10 @@ def __init__(self): super().__init__() def forward(self, x): - return x[[[[1,2,4]]]] - + return x[[[[1, 2, 4]]]] + torch_module = IndexTensorModel5().eval() - raw_data = np.random.rand(5,5,5).astype("float32") + raw_data = np.random.rand(5, 5, 5).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) @@ -179,10 +179,10 @@ def __init__(self): super().__init__() def forward(self, x): - return x[[[0,1],[0,1]]] - + return x[[[0, 1], [0, 1]]] + torch_module = IndexTensorModel6().eval() - raw_data = np.random.rand(5,5,5,5).astype("float32") + raw_data = np.random.rand(5, 5, 5, 5).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, 
dev) @@ -193,10 +193,12 @@ def __init__(self): super().__init__() def forward(self, x): - return x[[[0,1,2,3], [1,2,3,4], [2,3,4,0]]] # both args[0] and indices are expr.Var - + return x[ + [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 0]] + ] # both args[0] and indices are expr.Var + torch_module = IndexTensorModel7().eval() - raw_data = np.random.rand(5,5,5,5).astype("float32") + raw_data = np.random.rand(5, 5, 5, 5).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) @@ -207,10 +209,10 @@ def __init__(self): super().__init__() def forward(self, x): - return x[[[[0,1],[2,3]],[[2,3],[3,4]],[[2,4],[1,2]],[[0,4],[0,3]]]] - + return x[[[[0, 1], [2, 3]], [[2, 3], [3, 4]], [[2, 4], [1, 2]], [[0, 4], [0, 3]]]] + torch_module = IndexTensorModel8().eval() - raw_data = np.random.rand(5,5,5,5).astype("float32") + raw_data = np.random.rand(5, 5, 5, 5).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) @@ -786,6 +788,7 @@ def forward(self, x): raw_data = np.random.rand(10, 10, 10).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + @tvm.testing.parametrize_targets("cuda") def test_mul(target, dev): class MulModule(nn.Module): @@ -794,11 +797,11 @@ def __init__(self): self.y = torch.tensor(np.random.rand(2, 3).astype("float32")) def forward(self, x): - return x.mul(self.y) + return x.mul(self.y) torch_module = MulModule().eval() raw_data = np.random.rand(2, 3).astype("float32") - + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) @@ -813,11 +816,11 @@ def __init__(self, dim=0): self.x4 = torch.randn(2, 3) def forward(self, x): - return torch.cat((x ,self.x2, self.x3, self.x4), dim=self.dim) - + return torch.cat((x, self.x2, self.x3, self.x4), dim=self.dim) + torch_module = ConcatFour().eval() - raw_data = np.random.rand(2,3).astype("float32") + raw_data = np.random.rand(2, 3).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) From 01b23f02d2dd1fd756e2f2ca658144275847ba77 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Wed, 16 Apr 2025 19:22:09 -0400 Subject: [PATCH 091/105] pass correctness with first indices shape --- src/relax/op/tensor/manipulate.cc | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index e17d3dcd4574..168d849ddf95 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -489,23 +489,27 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) ctx->ReportFatal(Diagnostic::Error(call) << "Index.Tensor op should have 2 arguments"); } TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); + DataType output_dtype = data_sinfo->dtype; - Array tensor_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[1]); - if (tensor_sinfo.empty()) { + Array indices_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[1]); + if (indices_sinfo.empty()) { ctx->ReportFatal(Diagnostic::Error(call) << "Index.Tensor expects at least one tensor in the input Tuple. However, the " "given input Tuple is empty."); // TODO is this always true? 
} + TensorStructInfo first_indices_sinfo = indices_sinfo[0]; - DataType output_dtype = data_sinfo->dtype; - - TensorStructInfo indices_sinfo = data_sinfo; - - if (indices_sinfo->shape.defined()) { // TODO need this condition, but not sure why! Isn't that - // not reflective of the output anyway? - return TensorStructInfo(indices_sinfo->shape.value(), output_dtype, indices_sinfo->vdevice); + if (first_indices_sinfo->shape.defined()) { + LOG(INFO) << "USUALLY HERE " + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + "AAAAAAAAAAAAAAAAAAAAAAA"; + return TensorStructInfo(first_indices_sinfo->shape.value(), output_dtype, data_sinfo->vdevice); } else { - return TensorStructInfo(output_dtype, indices_sinfo->ndim, indices_sinfo->vdevice); + LOG(INFO) << " NOT HERE " + "BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB" + "BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"; + + return TensorStructInfo(output_dtype, data_sinfo->ndim, data_sinfo->vdevice); } } From 982719f89c00f03f2f0363d106573c70a0bacb4c Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Wed, 16 Apr 2025 19:28:45 -0400 Subject: [PATCH 092/105] corretness passes with indices shape --- src/relax/op/tensor/manipulate.cc | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 168d849ddf95..169d730de7da 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -493,22 +493,15 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) Array indices_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[1]); if (indices_sinfo.empty()) { - ctx->ReportFatal(Diagnostic::Error(call) - << "Index.Tensor expects at least one tensor in the input Tuple. However, the " - "given input Tuple is empty."); // TODO is this always true? + ctx->ReportFatal( + Diagnostic::Error(call) + << "Index.Tensor expects at least one tensor in the indices Tuple. However, the " + "given input Tuple is empty."); // TODO is this always true? } - TensorStructInfo first_indices_sinfo = indices_sinfo[0]; - if (first_indices_sinfo->shape.defined()) { - LOG(INFO) << "USUALLY HERE " - "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" - "AAAAAAAAAAAAAAAAAAAAAAA"; - return TensorStructInfo(first_indices_sinfo->shape.value(), output_dtype, data_sinfo->vdevice); + if (indices_sinfo[0]->shape.defined()) { + return TensorStructInfo(indices_sinfo[0]->shape.value(), output_dtype, data_sinfo->vdevice); } else { - LOG(INFO) << " NOT HERE " - "BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB" - "BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"; - return TensorStructInfo(output_dtype, data_sinfo->ndim, data_sinfo->vdevice); } } @@ -517,7 +510,7 @@ TVM_REGISTER_OP("relax.index_tensor") .set_num_inputs(2) .add_argument("data", "Tensor", "The input data.") .add_argument("indices", "List of Tensors", "The indices used to index.") - .set_attr("FInferStructInfo", InferStructInfoIndexTensor) // TODO necessary + .set_attr("FInferStructInfo", InferStructInfoIndexTensor) .set_attr("FPurity", Bool(true)); /* relax.layout_transform */ From b12484627a6462bf7b799d82b7cf245f7c959673 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Wed, 16 Apr 2025 20:47:50 -0400 Subject: [PATCH 093/105] correctness checks pass! 
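The struct-info inference for index_tensor now follows the NumPy advanced-indexing rule: the k index shapes are broadcast (best effort, via the arithmetic analyzer) to a common shape B, and the output shape is B + data.shape[k:]. When the broadcast shape cannot be proven, only the output rank is reported, or the shape is left unknown. A small NumPy check of the rule being modelled, for illustration only and not part of the diff below:

    import numpy as np

    data = np.zeros((5, 6, 7), dtype="float32")
    i0 = np.zeros((2, 1), dtype="int64")
    i1 = np.zeros((1, 3), dtype="int64")

    # broadcast((2, 1), (1, 3)) == (2, 3); two index tensors consume the
    # first two data dimensions, so the trailing 7 is kept.
    assert data[i0, i1].shape == (2, 3, 7)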
--- src/relax/op/tensor/manipulate.cc | 144 +++++++++++++++++++++++++++--- 1 file changed, 134 insertions(+), 10 deletions(-) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 169d730de7da..d07394a4596e 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -488,22 +488,146 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) if (call->args.size() != 2) { ctx->ReportFatal(Diagnostic::Error(call) << "Index.Tensor op should have 2 arguments"); } - TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); - DataType output_dtype = data_sinfo->dtype; + TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); Array indices_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[1]); + if (indices_sinfo.empty()) { - ctx->ReportFatal( - Diagnostic::Error(call) - << "Index.Tensor expects at least one tensor in the indices Tuple. However, the " - "given input Tuple is empty."); // TODO is this always true? + ctx->ReportFatal(Diagnostic::Error(call) + << "index_tensor expects a non‑empty tuple of index tensors"); } - if (indices_sinfo[0]->shape.defined()) { - return TensorStructInfo(indices_sinfo[0]->shape.value(), output_dtype, data_sinfo->vdevice); - } else { - return TensorStructInfo(output_dtype, data_sinfo->ndim, data_sinfo->vdevice); + DataType output_dtype = data_sinfo->dtype; + int n_indices = static_cast(indices_sinfo.size()); + Optional vdev = data_sinfo->vdevice; + + /* ------------------------------------------------------------------ * + * 2. Sanity‑checks: * + * • index dtype must be (u)int * + * • #indices ≤ data.ndim (if data.ndim known) * + * ------------------------------------------------------------------ */ + for (int i = 0; i < n_indices; ++i) { + const auto& s = indices_sinfo[i]; + if (!s->IsUnknownDtype() && !s->dtype.is_int()) { + ctx->ReportFatal(Diagnostic::Error(call) + << "index_tensor requires every index tensor to have an integer dtype; " + << "index #" << i << " has dtype " << s->dtype); + } + } + if (!data_sinfo->IsUnknownNdim() && n_indices > data_sinfo->ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "index_tensor received " << n_indices + << " index tensors, but data has only " << data_sinfo->ndim << " dimensions"); + } + + /* ------------------------------------------------------------------ * + * 3. Collect shape values of indices (when fully known). * + * We do *best‑effort* broadcasting analysis: * + * ‑ if all index shapes are ShapeExpr, try to compute * + * the common broadcast shape exactly; * + * ‑ otherwise fall back to “unknown shape”. 
* + * ------------------------------------------------------------------ */ + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + bool all_index_have_shape_value = true; + std::vector> index_shapes; // only filled when fully known + int max_index_ndim = 0; + + for (const auto& s : indices_sinfo) { + const auto* shp = s->shape.as(); + if (!shp) { + all_index_have_shape_value = false; + } else { + index_shapes.push_back(shp->values); + max_index_ndim = std::max(max_index_ndim, static_cast(shp->values.size())); + } + if (!s->IsUnknownNdim()) { + max_index_ndim = std::max(max_index_ndim, s->ndim); + } + } + + Optional> broadcast_shape; // the `B` in OutShape = B + data[k:] + bool shape_unknown = !all_index_have_shape_value; + + if (all_index_have_shape_value) { + // --- pair‑wise numpy‑style broadcasting -------------------------------- + + // initialise broadcast result with 1’s + Array out_shape; + for (int i = 0; i < max_index_ndim; ++i) { + out_shape.push_back(IntImm(DataType::Int(64), 1)); + } + + for (const auto& ishape : index_shapes) { + int cur_ndim = ishape.size(); + for (int axis = 0; axis < max_index_ndim; ++axis) { + int lhs_axis = max_index_ndim - 1 - axis; // aligned from right + int rhs_axis = cur_ndim - 1 - axis; + if (rhs_axis < 0) break; // shorter rank – done + + PrimExpr lhs_dim = out_shape[lhs_axis]; + PrimExpr rhs_dim = ishape[rhs_axis]; + + const auto* lhs_int = lhs_dim.as(); + const auto* rhs_int = rhs_dim.as(); + + // Case 1: current broadcast slot is 1 → always replace + if (lhs_int && lhs_int->value == 1) { + out_shape.Set(lhs_axis, rhs_dim); + continue; + } + // Case 2: rhs is 1 → keep lhs_dim unchanged + if (rhs_int && rhs_int->value == 1) { + continue; + } + // Both are non‑one constants: must equal + if (lhs_int && rhs_int && lhs_int->value != rhs_int->value) { + ctx->ReportFatal(Diagnostic::Error(call) + << "index_tensor: cannot broadcast index shapes; mismatch at axis " + << lhs_axis << ": " << lhs_dim << " vs " << rhs_dim); + } + // Otherwise (symbolics) – require provably equal, else give up + if (!analyzer->CanProveEqual(lhs_dim, rhs_dim)) { + shape_unknown = true; + break; + } + } + if (shape_unknown) break; + } + + if (!shape_unknown) broadcast_shape = out_shape; + } + + /* ------------------------------------------------------------------ * + * 4. Derive output ndim (= |B| + (data.ndim – k)) when possible. * + * ------------------------------------------------------------------ */ + int out_ndim = kUnknownNDim; + if (!data_sinfo->IsUnknownNdim()) { + int tail_ndim = data_sinfo->ndim - n_indices; + if (broadcast_shape.defined()) { + out_ndim = static_cast(broadcast_shape.value().size()) + tail_ndim; + } else if (!shape_unknown) { + out_ndim = max_index_ndim + tail_ndim; + } } + + /* ------------------------------------------------------------------ * + * 5. Construct final explicit output shape when fully known. * + * ------------------------------------------------------------------ */ + if (broadcast_shape.defined()) { + const auto* data_shape_expr = data_sinfo->shape.as(); + if (data_shape_expr) { + Array result_shape = broadcast_shape.value(); + for (int i = n_indices; i < data_sinfo->ndim; ++i) { + result_shape.push_back(data_shape_expr->values[i]); + } + return TensorStructInfo(ShapeExpr(result_shape), output_dtype, vdev); + } + } + + /* ------------------------------------------------------------------ * + * 6. Fallback: known rank only, or completely unknown. 
* + * ------------------------------------------------------------------ */ + return TensorStructInfo(output_dtype, out_ndim, vdev); } TVM_REGISTER_OP("relax.index_tensor") From f1737b2705845d6d671978c59c2a8e6ef4bf1f03 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Wed, 16 Apr 2025 21:02:19 -0400 Subject: [PATCH 094/105] cleanup - still passes correctness --- src/relax/op/tensor/manipulate.cc | 43 ++++++++++--------------------- 1 file changed, 13 insertions(+), 30 deletions(-) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index d07394a4596e..403e4bf438fe 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -501,35 +501,26 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) int n_indices = static_cast(indices_sinfo.size()); Optional vdev = data_sinfo->vdevice; - /* ------------------------------------------------------------------ * - * 2. Sanity‑checks: * - * • index dtype must be (u)int * - * • #indices ≤ data.ndim (if data.ndim known) * - * ------------------------------------------------------------------ */ + // Indices must be integers for (int i = 0; i < n_indices; ++i) { const auto& s = indices_sinfo[i]; if (!s->IsUnknownDtype() && !s->dtype.is_int()) { ctx->ReportFatal(Diagnostic::Error(call) << "index_tensor requires every index tensor to have an integer dtype; " - << "index #" << i << " has dtype " << s->dtype); + << "index " << i << " has dtype " << s->dtype); } } + + // Count of indices must be less than or equal to data.ndim if (!data_sinfo->IsUnknownNdim() && n_indices > data_sinfo->ndim) { ctx->ReportFatal(Diagnostic::Error(call) << "index_tensor received " << n_indices << " index tensors, but data has only " << data_sinfo->ndim << " dimensions"); } - /* ------------------------------------------------------------------ * - * 3. Collect shape values of indices (when fully known). * - * We do *best‑effort* broadcasting analysis: * - * ‑ if all index shapes are ShapeExpr, try to compute * - * the common broadcast shape exactly; * - * ‑ otherwise fall back to “unknown shape”. 
* - * ------------------------------------------------------------------ */ arith::Analyzer* analyzer = ctx->GetAnalyzer(); bool all_index_have_shape_value = true; - std::vector> index_shapes; // only filled when fully known + std::vector> index_shapes; int max_index_ndim = 0; for (const auto& s : indices_sinfo) { @@ -545,12 +536,10 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) } } - Optional> broadcast_shape; // the `B` in OutShape = B + data[k:] + Optional> broadcast_shape; bool shape_unknown = !all_index_have_shape_value; if (all_index_have_shape_value) { - // --- pair‑wise numpy‑style broadcasting -------------------------------- - // initialise broadcast result with 1’s Array out_shape; for (int i = 0; i < max_index_ndim; ++i) { @@ -570,22 +559,22 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) const auto* lhs_int = lhs_dim.as(); const auto* rhs_int = rhs_dim.as(); - // Case 1: current broadcast slot is 1 → always replace + // Case 1: current broadcast slot is 1 -> always replace if (lhs_int && lhs_int->value == 1) { out_shape.Set(lhs_axis, rhs_dim); continue; } - // Case 2: rhs is 1 → keep lhs_dim unchanged + // Case 2: rhs is 1 -> keep lhs_dim unchanged if (rhs_int && rhs_int->value == 1) { continue; } // Both are non‑one constants: must equal if (lhs_int && rhs_int && lhs_int->value != rhs_int->value) { ctx->ReportFatal(Diagnostic::Error(call) - << "index_tensor: cannot broadcast index shapes; mismatch at axis " + << "index_tensor: cannot broadcast index shapes. Mismatch at axis " << lhs_axis << ": " << lhs_dim << " vs " << rhs_dim); } - // Otherwise (symbolics) – require provably equal, else give up + // Give up if not provablt equal if (!analyzer->CanProveEqual(lhs_dim, rhs_dim)) { shape_unknown = true; break; @@ -597,9 +586,7 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) if (!shape_unknown) broadcast_shape = out_shape; } - /* ------------------------------------------------------------------ * - * 4. Derive output ndim (= |B| + (data.ndim – k)) when possible. * - * ------------------------------------------------------------------ */ + // Count of dimensions in output int out_ndim = kUnknownNDim; if (!data_sinfo->IsUnknownNdim()) { int tail_ndim = data_sinfo->ndim - n_indices; @@ -610,9 +597,7 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) } } - /* ------------------------------------------------------------------ * - * 5. Construct final explicit output shape when fully known. * - * ------------------------------------------------------------------ */ + // Derive output shape if (broadcast_shape.defined()) { const auto* data_shape_expr = data_sinfo->shape.as(); if (data_shape_expr) { @@ -624,9 +609,7 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) } } - /* ------------------------------------------------------------------ * - * 6. Fallback: known rank only, or completely unknown. 
* - * ------------------------------------------------------------------ */ + // Unknown output shape return TensorStructInfo(output_dtype, out_ndim, vdev); } From 47adecf61b52bea86c57f0f7a3f619b4b837e650 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Wed, 16 Apr 2025 21:22:23 -0400 Subject: [PATCH 095/105] comments --- python/tvm/relax/op/manipulate.py | 57 +++++++++++++++++++----- python/tvm/script/ir_builder/relax/ir.py | 4 +- python/tvm/topi/transform.py | 48 +++++++++++++++++++- src/relax/op/tensor/manipulate.h | 15 +++---- 4 files changed, 102 insertions(+), 22 deletions(-) diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 58a2d686f2bb..1eb90cfdcb8a 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -508,24 +508,61 @@ def gather_nd(data: Expr, indices: Expr, batch_dims: int = 0) -> Expr: return _ffi_api.gather_nd(data, indices, batch_dims) # type: ignore -# TODO change names of args and remove axis arg def index_tensor(data: Expr, indices: Union[Expr, List[Expr]]) -> Expr: - """Concatenate the input tensors along the given axis. + """Advanced‑tensor indexing (NumPy/PyTorch‐style). + + Given k index tensors ``indices = (I0, I1, …, Ik‑1)`` this + operator selects elements from ``data`` as if one had written + ``data[I0, I1, …, Ik‑1]`` in NumPy/PyTorch: + + * All index tensors must have an integer dtype. + * Their shapes are broadcast together to a common shape ``B`` in + the usual NumPy way. + * The result shape is ``B + data.shape[k:]`` (i.e. the broadcast + shape followed by the remaining axes of ``data`` that are *not* + indexed). + * At compile‑time Relax checks + * the number of index tensors ``k`` does not exceed + ``data.ndim``, + * the dtypes are integer, + * the shapes are consitent (broadcast‑compatible). Parameters ---------- - tensors : Union[relax.Expr, List[relax.Expr]] - An Expr in Tuple type, containing the tensors to be concatenated, - or a list of Tensors. + data : relax.Expr + The input tensor to be indexed. - axis : Optional[int] - The axis along which the tensors are concatenated. - If `axis` is `None`, the input tensor is required to be flattened before concatenation. + indices : Union[relax.Expr, List[relax.Expr]] + A Tuple expression containing the index tensors, + or a Python ``list`` / ``tuple`` that will be promoted to a + tuple expression automatically. Each tensor must have an + integer dtype. Returns ------- - result: relax.Expr - The concatenated tensor. + result : relax.Expr + The tensor obtained after advanced indexing. Its dtype equals + ``data.dtype`` + + Examples + -------- + .. code-block:: python + + import numpy as np + import tvm.relax as R + + x = R.const(np.arange(9).reshape(3, 3).astype("float32")) + row = R.const(np.array([0, 2])) # shape (2,) + col = R.const(np.array([1, 0])) # shape (2,) + + y = R.index_tensor(x, [row, col]) + # y.shape == (2,) ; y == [1., 6.] + + # Broadcasting: row : (2,1), col : (1,3) → B = (2,3) + row = R.const(np.array([[0],[1]])) + col = R.const(np.array([[0,1,2]])) + z = R.index_tensor(x, [row, col]) + # z.shape == (2,3) """ if isinstance(indices, (list, tuple)): indices = RxTuple(indices) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index e1687a6d5343..18aa87855908 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -101,7 +101,7 @@ greater_equal, hint_on_device, image, - index_tensor, # TODO do something with this or remove? 
+ index_tensor, invoke_closure, invoke_pure_closure, isfinite, @@ -784,7 +784,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "hexagon", "hint_on_device", "image", - "index_tensor", # TODO keep or remove? + "index_tensor", "invoke_closure", "invoke_pure_closure", "isfinite", diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index c2b826f43a8e..d80a5b3438fd 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1053,8 +1053,52 @@ def _apply_trilu(*indices): return te.compute(data.shape, _apply_trilu, name="trilu", tag=topi.tag.ELEMWISE) +def index_tensor(data, indices): + """Advanced‑tensor indexing (NumPy/PyTorch‐style). + + Given k index tensors ``indices = (I0, I1, …, Ik‑1)`` this + operator selects elements from ``data`` as if one had written + ``data[I0, I1, …, Ik‑1]`` in NumPy/PyTorch: + + * All index tensors must have an integer dtype. + * Their shapes are broadcast together to a common shape ``B`` in + the usual NumPy way. + * The result shape is ``B + data.shape[k:]`` (i.e. the broadcast + shape followed by the remaining axes of ``data`` that are *not* + indexed). + * ``k`` must not exceed ``data.ndim``; otherwise a compile‑time + error is raised. -def index_tensor(data, indices): # TODO remove axis argument - """TODO docstring""" + Parameters + ---------- + data : tvm.te.Tensor + The tensor to be indexed. + + indices : Sequence[tvm.te.Tensor] + A Python ``list`` / ``tuple`` of **k** index tensors, + or a `tvm.te.Tensor` tuple expression. Each tensor must have an + integer dtype. + + Returns + ------- + result : tvm.te.Tensor + The tensor obtained after advanced indexing. Its dtype equals + ``data.dtype`` + Examples + -------- + .. code-block:: python + + x = te.placeholder((3, 3), name="x") # shape (3,3) + row = te.placeholder((2,), name="row", dtype="int32") + col = te.placeholder((2,), name="col", dtype="int32") + + # Equivalent to x[row, col] in NumPy / PyTorch + y = topi.index_tensor(x, [row, col]) # shape (2,) + + # Broadcasting example: + row = te.placeholder((2, 1), name="row", dtype="int32") + col = te.placeholder((1, 3), name="col", dtype="int32") + z = topi.index_tensor(x, [row, col]) # shape (2, 3) + """ return topi.adv_index(data, indices) diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 0c7d482ffacb..7b6c8420170d 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -200,16 +200,15 @@ Expr gather_elements(Expr data, Expr indices, int axis = 0); */ Expr gather_nd(Expr data, Expr indices, int batch_dims = 0); -/*! // TODO update this comment - * \brief Gather values from a tensor using N-dimensional indices. +/*! + * \brief NumPy/PyTorch‑style advanced indexing with tensors. * \param data The input tensor. - * \param indices The indices tensor, must have integer type. - * \return The computed result. + * \param indices A Tuple expression (or list) containing the index tensors. + * \return The indexed tensor. * - * \note For batch_dims > 0, the first batch_dims dimensions of data and indices must be equal. - * The last dimension of indices indicates the depth of each index vector. - * The output shape is batch_dims + indices.shape[:-1] + data.shape[batch_dims + - * indices.shape[-1]:] + * \note When all shapes are static, Relax checks that the index shapes are + * broadcast-compatible. Bounds checking of the values in indices is + * deferred to runtime. 
*/ Expr index_tensor(Expr data, Expr indices); From 8626f59a43dd1ee5ce3d29106a071320cc088d7e Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Wed, 16 Apr 2025 21:31:29 -0400 Subject: [PATCH 096/105] all pass --- python/tvm/topi/transform.py | 1 + src/relax/op/tensor/manipulate.cc | 1 - .../relax/test_from_exported_to_cuda.py | 65 ------------------- 3 files changed, 1 insertion(+), 66 deletions(-) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index d80a5b3438fd..52dc03100461 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -1053,6 +1053,7 @@ def _apply_trilu(*indices): return te.compute(data.shape, _apply_trilu, name="trilu", tag=topi.tag.ELEMWISE) + def index_tensor(data, indices): """Advanced‑tensor indexing (NumPy/PyTorch‐style). diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 403e4bf438fe..624d0b884b48 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -30,7 +30,6 @@ #include #include -#include "tvm/relax/type.h" // kUnknownNDim #include "tvm/runtime/data_type.h" namespace tvm { diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index da4197d9ce39..75cb518369ac 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -62,23 +62,6 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar desired = pytorch_out.detach().numpy() np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) - -@tvm.testing.parametrize_targets("cuda") -def test_full(target, dev): - class FullModel(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.full((2, 3), 3.141592) - - torch_module = FullModel().eval() - - raw_data = np.random.rand(3, 3).astype("float32") - - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - - # Test index.Tensor # TODO aggregate into one big tet @@ -264,54 +247,6 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) -@tvm.testing.parametrize_targets("cuda") -def test_full(target, dev): - class FullModel(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.full((2, 3), 3.141592) - - torch_module = FullModel().eval() - - raw_data = np.random.rand(3, 3).astype("float32") - - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - - -@tvm.testing.parametrize_targets("cuda") -def test_full_like(target, dev): - class FullLike(nn.Module): - def __init__(self): - super().__init__() - self.fill_value = 7.0 - - def forward(self, x): - return torch.full_like(x, self.fill_value) - - torch_module = FullLike().eval() - raw_data = np.random.rand(2, 3).astype("float32") - - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - - -@tvm.testing.parametrize_targets("cuda") -def test_ones(target, dev): - class FullModel(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.ones((2, 3)) - - torch_module = FullModel().eval() - - raw_data = np.random.rand(1, 1).astype("float32") - - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - - @tvm.testing.parametrize_targets("cuda") def test_tensor_clamp(target, dev): class ClampBothTensor(torch.nn.Module): From 
93c0ac19b7105fc46e653b1a2a1943fd64a414bf Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Wed, 16 Apr 2025 21:32:47 -0400 Subject: [PATCH 097/105] combine into one test. all pass --- .../relax/test_from_exported_to_cuda.py | 32 ++----------------- 1 file changed, 2 insertions(+), 30 deletions(-) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 75cb518369ac..57191b74e52e 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -62,11 +62,9 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar desired = pytorch_out.detach().numpy() np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) -# Test index.Tensor # TODO aggregate into one big tet - @tvm.testing.parametrize_targets("cuda") -def test_index_tensor0(target, dev): +def test_index_tensor(target, dev): class IndexModel0(nn.Module): def __init__(self): super().__init__() @@ -80,9 +78,6 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - -@tvm.testing.parametrize_targets("cuda") -def test_index_tensor1(target, dev): class IndexModel1(nn.Module): def __init__(self): super().__init__() @@ -96,9 +91,6 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - -@tvm.testing.parametrize_targets("cuda") -def test_index_tensor2(target, dev): class IndexTensorModel2(nn.Module): def __init__(self): super().__init__() @@ -112,9 +104,6 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - -@tvm.testing.parametrize_targets("cuda") -def test_index_tensor3(target, dev): class IndexTensorModel3(nn.Module): def __init__(self): super().__init__() @@ -126,9 +115,6 @@ def forward(self, x): raw_data = np.random.rand(5, 5, 5).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - -@tvm.testing.parametrize_targets("cuda") -def test_index_tensor4(target, dev): class IndexTensorModel4(nn.Module): def __init__(self): super().__init__() @@ -140,9 +126,6 @@ def forward(self, x): raw_data = np.random.rand(5, 5, 5).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - -@tvm.testing.parametrize_targets("cuda") -def test_index_tensor5(target, dev): class IndexTensorModel5(nn.Module): def __init__(self): super().__init__() @@ -154,9 +137,6 @@ def forward(self, x): raw_data = np.random.rand(5, 5, 5).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - -@tvm.testing.parametrize_targets("cuda") -def test_index_tensor6(target, dev): class IndexTensorModel6(nn.Module): def __init__(self): super().__init__() @@ -168,25 +148,17 @@ def forward(self, x): raw_data = np.random.rand(5, 5, 5, 5).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - -@tvm.testing.parametrize_targets("cuda") -def test_index_tensor7(target, dev): class IndexTensorModel7(nn.Module): def __init__(self): super().__init__() def forward(self, x): - return x[ - [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 0]] - ] # both args[0] and indices are expr.Var + return x[[[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 0]]] torch_module = IndexTensorModel7().eval() raw_data = np.random.rand(5, 5, 5, 5).astype("float32") 
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - -@tvm.testing.parametrize_targets("cuda") -def test_index_tensor8(target, dev): class IndexTensorModel8(nn.Module): def __init__(self): super().__init__() From 76a3e212b0827e9bd56b5bfe8baa7ed487d7d559 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Wed, 16 Apr 2025 22:38:00 -0400 Subject: [PATCH 098/105] blank line --- python/tvm/relax/op/manipulate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 1eb90cfdcb8a..3639f35f6606 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -563,6 +563,7 @@ def index_tensor(data: Expr, indices: Union[Expr, List[Expr]]) -> Expr: col = R.const(np.array([[0,1,2]])) z = R.index_tensor(x, [row, col]) # z.shape == (2,3) + """ if isinstance(indices, (list, tuple)): indices = RxTuple(indices) From 695edf1c405d356a69941495ce3806c064ffb464 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Thu, 17 Apr 2025 00:36:32 -0400 Subject: [PATCH 099/105] docs indentation --- python/tvm/relax/op/manipulate.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 3639f35f6606..952e2a45dfed 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -515,17 +515,18 @@ def index_tensor(data: Expr, indices: Union[Expr, List[Expr]]) -> Expr: operator selects elements from ``data`` as if one had written ``data[I0, I1, …, Ik‑1]`` in NumPy/PyTorch: - * All index tensors must have an integer dtype. - * Their shapes are broadcast together to a common shape ``B`` in - the usual NumPy way. - * The result shape is ``B + data.shape[k:]`` (i.e. the broadcast - shape followed by the remaining axes of ``data`` that are *not* - indexed). - * At compile‑time Relax checks - * the number of index tensors ``k`` does not exceed - ``data.ndim``, - * the dtypes are integer, - * the shapes are consitent (broadcast‑compatible). + All index tensors must have an integer dtype. + + Their shapes are broadcast together to a common shape ``B`` in + the usual NumPy way. + + The result shape is ``B + data.shape[k:]`` (i.e. the broadcast + shape followed by the remaining axes of ``data`` that are *not* + indexed). + + At compile‑time Relax checks that the number of index tensors + ``k`` does not exceed ``data.ndim``, that the dtypes are integer, + and that the shapes are consitent (broadcast‑compatible). Parameters ---------- From e505fecf452c0d37ba66b8a906514f203ce75f73 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Thu, 17 Apr 2025 00:49:48 -0400 Subject: [PATCH 100/105] lint --- python/tvm/relax/op/manipulate.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 952e2a45dfed..522abb40dcd3 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -516,15 +516,15 @@ def index_tensor(data: Expr, indices: Union[Expr, List[Expr]]) -> Expr: ``data[I0, I1, …, Ik‑1]`` in NumPy/PyTorch: All index tensors must have an integer dtype. - + Their shapes are broadcast together to a common shape ``B`` in the usual NumPy way. - + The result shape is ``B + data.shape[k:]`` (i.e. the broadcast shape followed by the remaining axes of ``data`` that are *not* indexed). 
- - At compile‑time Relax checks that the number of index tensors + + At compile‑time Relax checks that the number of index tensors ``k`` does not exceed ``data.ndim``, that the dtypes are integer, and that the shapes are consitent (broadcast‑compatible). From 788750216dac86ff5672c9e768637a7497cedd90 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Thu, 17 Apr 2025 08:15:01 -0400 Subject: [PATCH 101/105] dummy whitespace change to trigger tests --- tests/python/relax/test_from_exported_to_cuda.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 57191b74e52e..05b81417a2fc 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -73,9 +73,7 @@ def forward(self, x): return x[torch.tensor([0])] torch_module = IndexModel0().eval() - raw_data = np.random.rand(3, 3).astype("float32") - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) class IndexModel1(nn.Module): @@ -86,9 +84,7 @@ def forward(self, x): return x[torch.tensor([[0]])] torch_module = IndexModel1().eval() - raw_data = np.random.rand(2, 3).astype("float32") - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) class IndexTensorModel2(nn.Module): @@ -99,9 +95,7 @@ def forward(self, x): return x[torch.tensor([0, 2])] torch_module = IndexTensorModel2().eval() - raw_data = np.random.rand(3, 4).astype("float32") - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) class IndexTensorModel3(nn.Module): @@ -181,9 +175,7 @@ def forward(self, x): return torch.full((2, 3), 3.141592) torch_module = FullModel().eval() - raw_data = np.random.rand(3, 3).astype("float32") - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) @@ -199,7 +191,6 @@ def forward(self, x): torch_module = FullLike().eval() raw_data = np.random.rand(2, 3).astype("float32") - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) From 9f260f27a550bc4e2bbbaccdbaa849b4bd5dbf0e Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Thu, 17 Apr 2025 08:16:10 -0400 Subject: [PATCH 102/105] whitespace --- tests/python/relax/test_from_exported_to_cuda.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 05b81417a2fc..b5a6e59dbe5d 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -682,7 +682,6 @@ def forward(self, x): return new_vec.sum() torch_module = SumModel().eval() - raw_data = np.random.rand(10, 10, 10).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) @@ -699,7 +698,6 @@ def forward(self, x): torch_module = MulModule().eval() raw_data = np.random.rand(2, 3).astype("float32") - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) From d70fba18128ffd1da936765ef063497cd2aa1241 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Thu, 17 Apr 2025 08:16:35 -0400 Subject: [PATCH 103/105] whitespace --- tests/python/relax/test_from_exported_to_cuda.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index b5a6e59dbe5d..8f4227325d65 100644 --- 
a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -715,9 +715,7 @@ def forward(self, x): return torch.cat((x, self.x2, self.x3, self.x4), dim=self.dim) torch_module = ConcatFour().eval() - raw_data = np.random.rand(2, 3).astype("float32") - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) From d620c4f3abf6ddd16e1e52cd6e7ba72d28c4b106 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Thu, 17 Apr 2025 09:20:44 -0400 Subject: [PATCH 104/105] dummy whitespace change to trigger tests --- tests/python/relax/test_from_exported_to_cuda.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 8f4227325d65..76a4bb203925 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -204,9 +204,7 @@ def forward(self, x): return torch.ones((2, 3)) torch_module = FullModel().eval() - raw_data = np.random.rand(1, 1).astype("float32") - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) From 46841a3b9c1a31d50121ad9d0693d9125f76dd9f Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sat, 19 Apr 2025 10:43:12 -0400 Subject: [PATCH 105/105] no backtracking --- python/tvm/relax/transform/legalize_ops/manipulate.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 68ed9d423a92..a22a82ebbeb0 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -167,16 +167,7 @@ def te_gather_nd(data, indices, batch_dims): def _index_tensor(bb: BlockBuilder, call: Call) -> Expr: t = call.args[1] n_field = len(t.struct_info.fields) - while isinstance(t, Var): - binding = bb.lookup_binding(t) - if not isinstance(binding, (Tuple, Var)): - break - t = binding - - assert isinstance(t, (Tuple, Var)) - fields = ( - t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)] - ) + fields = [bb.emit(TupleGetItem(t, i)) for i in range(n_field)] return bb.call_te(topi.index_tensor, call.args[0], fields)
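
The series implements NumPy/PyTorch advanced indexing end to end: `InferStructInfoIndexTensor` derives the output shape as the broadcast of the index shapes followed by the un-indexed trailing axes of `data`, and the legalization above (PATCH 105) unpacks the indices tuple with `TupleGetItem` and calls `topi.index_tensor`, which wraps `topi.adv_index`. The snippet below is a minimal NumPy-only sketch of that shape rule, reusing the examples from the `index_tensor` docstring added in PATCH 095; it assumes nothing beyond NumPy and is an illustration rather than part of any patch.

    import numpy as np

    # Shape rule checked by InferStructInfoIndexTensor:
    #   out.shape == broadcast(I0.shape, ..., Ik-1.shape) + data.shape[k:]
    data = np.arange(9).reshape(3, 3).astype("float32")

    # k == 2, both index arrays have shape (2,): output shape is (2,)
    row = np.array([0, 2])
    col = np.array([1, 0])
    assert np.array_equal(data[row, col], np.array([1.0, 6.0], dtype="float32"))

    # Broadcasting: (2, 1) and (1, 3) broadcast to B = (2, 3)
    row = np.array([[0], [1]])
    col = np.array([[0, 1, 2]])
    assert data[row, col].shape == (2, 3)

    # Partial indexing: k == 1 index of shape (2,) on a (3, 3) tensor
    # gives B + data.shape[1:] == (2,) + (3,) == (2, 3)
    assert data[np.array([0, 2])].shape == (2, 3)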