From 0c8c9046419822a0028611564f01dca778d4c590 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 02:34:34 -0400 Subject: [PATCH 01/18] suddenly copy.default is unsupported --- tests/python/relax/test_from_exported_to_cuda.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 6cc12370d648..7f0204cd2e1f 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -202,6 +202,11 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_copy_fails_(target, dev): + assert False, "test_copy_fails_ indeed fails" + + @tvm.testing.parametrize_targets("cuda") def test_upsample_with_size(target, dev): """ From f7f063786d0f26ba1ed79b6d1ffaaf858055f54b Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 14:58:58 -0400 Subject: [PATCH 02/18] wip --- .../torch/base_fx_graph_translator.py | 20 +++++++++++++++++++ .../torch/exported_program_translator.py | 1 + 2 files changed, 21 insertions(+) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 6bbc9d5de618..40342a830674 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1021,6 +1021,7 @@ def _scatter(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.scatter_elements(x, index, src, axis=dim)) def _split(self, node: fx.Node) -> relax.Var: + """ torch.split with split_size passed as an argument""" x = self.env[node.args[0]] split_size = node.args[1] dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0) @@ -1031,6 +1032,25 @@ def _split(self, node: fx.Node) -> relax.Var: n_section.append(s + cum_sum) else: n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size + print("calling split with n_section", n_section) + print("calling split with dim", dim) + return self.block_builder.emit(relax.op.split(x, n_section, dim)) + + def _split_with_sizes(self, node: fx.Node) -> relax.Var: + """ torch.split with a list of section sizes passed as an argument""" + # TODO + x = self.env[node.args[0]] + split_size = node.args[1] + dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0) + if isinstance(split_size, (list, tuple)): + n_section = [] + for s in split_size[:-1]: + cum_sum = 0 if not n_section else n_section[-1] + n_section.append(s + cum_sum) + else: + n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size + print("calling split with n_section", n_section) + print("calling split with dim", dim) return self.block_builder.emit(relax.op.split(x, n_section, dim)) def _squeeze(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index bc7a4c4cb046..834002a02170 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -329,6 +329,7 @@ def create_convert_map( "select.int": self._select, "slice.Tensor": self._slice, "split.Tensor": self._split, + "split_with_sizes.default": self._split_with_sizes, "squeeze.default": self._squeeze, "squeeze.dim": self._squeeze, "take.default": self._take, 
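A note on the semantics the series is wiring up here: torch.split takes either a single split_size or an explicit list of section sizes, while relax.op.split takes either a number of sections or a list of cut indices. The running-total loop in _split above, the ceil-division that patch 03 introduces in split_n_sections, and the torch.chunk rule handled later in the series (patches 10 and 18) can all be checked against a small self-contained sketch. This is plain Python for illustration only; the helper names are not TVM APIs:

import math

def split_cut_points(dim_len, split_size_or_sections):
    # torch.split with a list of sizes: cut after each running total
    # (the last section is implied), e.g. [3, 2, 5] -> cut indices [3, 5].
    if isinstance(split_size_or_sections, (list, tuple)):
        points, total = [], 0
        for size in split_size_or_sections[:-1]:
            total += size
            points.append(total)
        return points
    # torch.split with an int: ceil-division gives the section count, and the
    # tail section may be smaller, e.g. length 7 with split_size 3 -> (3, 3, 1).
    n_sections = math.ceil(dim_len / split_size_or_sections)
    return [i * split_size_or_sections for i in range(1, n_sections)]

def chunk_sizes(dim_len, chunks):
    # torch.chunk: each chunk holds ceil(dim_len / chunks) elements, so fewer
    # than `chunks` pieces can come back, e.g. length 3 with chunks=99 -> (1, 1, 1).
    per_chunk = math.ceil(dim_len / chunks)
    full, remainder = divmod(dim_len, per_chunk)
    return (per_chunk,) * full + ((remainder,) if remainder else ())

assert split_cut_points(10, [3, 2, 5]) == [3, 5]
assert split_cut_points(7, 3) == [3, 6]
assert chunk_sizes(7, 2) == (4, 3)
assert chunk_sizes(3, 99) == (1, 1, 1)

The list branch is the same cumulative sum _split computes when turning a sizes list into relax.op.split indices, and the ceil-division matches the seg_size change patch 03 makes below so that the integer-sections path tolerates an axis length that is not divisible by the section count.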
From 7e4cf058ba43bd49904905362301dd26d50a6b68 Mon Sep 17 00:00:00 2001
From: Hugo Latendresse
Date: Sun, 16 Mar 2025 16:48:01 -0400
Subject: [PATCH 03/18] Able to split uneven tensors!

Remaining TODOs
Make sure that also works if we input a list of indices
Write tests at the py and cpp level
Cleanup

---
 include/tvm/topi/transform.h                  | 32 +++++++++++--------
 python/tvm/relax/op/manipulate.py             |  1 +
 .../transform/legalize_ops/manipulate.py      | 15 +++++----
 .../framework/tensorrt/transform_tensorrt.cc  |  2 ++
 src/relax/op/tensor/manipulate.cc             |  4 +++
 src/topi/transform.cc                         |  7 ++--
 6 files changed, 39 insertions(+), 22 deletions(-)

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index faacd2ce5760..62f69d9c825e 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -575,8 +575,19 @@ inline Tensor stack(const Array<Tensor>& inputs, int axis = 0, std::string name
  *
  * \return A Tensor whose op member is the split operation
  */
-inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int axis,
-                           std::string name = "T_split", std::string tag = kInjective) {
+inline Array<Tensor> split_indices_array(const Tensor& x, Array<PrimExpr> split_indices, int axis,
+                                         std::string name = "T_split",
+                                         std::string tag = kInjective) {
+  printf("we are in the transform'h's split\n");
+  // int x11 = 10;
+  // int y11 = 5;
+  // while (y11) {
+  //   x11 -= 2;
+  //   y11--;
+  // }
+  // int z11 = 50 / x11;
+  // printf("z11 is %d\n", z11);
+
   if (axis < 0) {
     axis += static_cast<int>(x->shape.size());
   }
@@ -968,9 +979,10 @@ inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const
  *
  * \return A Tensor whose op member is the split operation
  */
-inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
-                                    std::string name = "T_split_sections",
-                                    std::string tag = kInjective) {
+inline Array<Tensor> split_n_sections(const Tensor& x, int num_sections, int axis,
+                                      std::string name = "T_split_sections",
+                                      std::string tag = kInjective) {
+  printf("We are in transform.h's splits_sections\n");
   if (axis < 0) {
     axis += static_cast<int>(x->shape.size());
   }
@@ -980,14 +992,8 @@ inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
 
   ICHECK_GT(num_sections, 0) << "Slice count must be > 0";
 
-  if (auto node = src_axis_size.as<IntImmNode>()) {
-    ICHECK_EQ(node->value % num_sections, 0)
-        << "num_sections must be an integer factor of the size of axis " << axis << " ("
-        << node->value << ")";
-  }
-
   Array<PrimExpr> split_indices;
-  auto seg_size = indexdiv(src_axis_size, num_sections);
+  auto seg_size = indexdiv(src_axis_size + num_sections - 1, num_sections);
   for (int i = 0; i < num_sections; ++i) {
     // region at index 0 is added by split()
     if (i != 0) {
@@ -995,7 +1001,7 @@ inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
     }
   }
 
-  return split(x, split_indices, axis, name, tag);
+  return split_indices_array(x, split_indices, axis, name, tag);
 }
 
 /*!
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index 0f6e537ab3d6..85f08b7f98bd 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -253,6 +253,7 @@ def split(
     """
     if isinstance(indices_or_sections, int):
        indices_or_sections = IntImm("int64", indices_or_sections)
+    print("CALLING _ffi_api.split !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     return _ffi_api.split(x, indices_or_sections, axis)  # type: ignore
 
 
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index c71a41dc1c2d..cefbba55fce3 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -107,18 +107,19 @@ def _permute_dims(bb: BlockBuilder, call: Call) -> Expr:
 
 @register_legalize("relax.split")
 def _split(bb: BlockBuilder, call: Call) -> Expr:
+    print("is this the split op??????????????????????????????????/")
     if isinstance(call.attrs.indices_or_sections, tir.IntImm):
         indices_or_sections = call.attrs.indices_or_sections.value
         modulo = tvm.arith.Analyzer().simplify(
             call.args[0].struct_info.shape.values[call.attrs.axis] % indices_or_sections
         )
-        if isinstance(modulo, tir.IntImm):
-            if modulo != 0:
-                logging.info(
-                    "Split cannot be legalized by TOPI when the axis being split has "
-                    "length that not divisible by the input number of section."
-                )
-                return call
+        # if isinstance(modulo, tir.IntImm):
+        #     if modulo != 0:
+        #         logging.info(
+        #             "Split cannot be legalized by TOPI when the axis being split has "
+        #             "length that not divisible by the input number of section."
+        #         )
+        #         return call
     else:
         indices_or_sections = call.attrs.indices_or_sections
     return bb.call_te(topi.split, call.args[0], indices_or_sections, call.attrs.axis)
diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
index 94dfec7ea621..92d090946cdb 100644
--- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
+++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
@@ -782,6 +782,8 @@ Expr RewriteShapeLike(BlockBuilder builder, const Var& var, const Call& src_call
 
 Expr RewriteSplit(BlockBuilder builder, const Var& var, const Call& src_call,
                   const Map<Expr, Call>& new_calls, const String& config) {
+  printf("INSIDE REWRITESPLIT !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!11\n");
+  std::cout << "Why is this not building?" << std::endl;
   const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
   const auto& input_shape = ExprUtils::GetShape(call->args[0]);
   const auto* src_attrs = src_call->attrs.as<SplitAttrs>();
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index cb738db363ee..f40f0ac36e0a 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -824,7 +824,10 @@ TVM_REGISTER_OP("relax.reshape")
 /* relax.split */
 TVM_REGISTER_NODE_TYPE(SplitAttrs);
 
+#include
+
 Expr split(Expr x, Variant<IntImm, Array<IntImm>> indices_or_sections, int axis) {
+  printf("INSIDE SPLIT!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n");
   ObjectPtr<SplitAttrs> attrs = make_object<SplitAttrs>();
   if (const auto* indices = indices_or_sections.as<ArrayNode>()) {
     for (int i = 0; i < static_cast<int>(indices->size()); ++i) {
@@ -848,6 +851,7 @@ Expr split(Expr x, Variant<IntImm, Array<IntImm>> indices_or_sections, int axis)
   attrs->axis = axis;
 
   static const Op& op = Op::Get("relax.split");
+  printf("CALLING OP !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n");
   return Call(op, {std::move(x)}, Attrs(attrs), {});
 }
 
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index 2e0fde3b289f..a0df64faa291 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -82,11 +82,14 @@ TVM_REGISTER_GLOBAL("topi.ndarray_size").set_body([](TVMArgs args, TVMRetValue*
   *rv = ndarray_size(args[0], args[1]);
 });
 
+#include
 TVM_REGISTER_GLOBAL("topi.split").set_body([](TVMArgs args, TVMRetValue* rv) {
+  printf("we are in transform.cc's topi.split, called by _split python func in manipulate.py\n");
+
   if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) {
-    *rv = split_sections(args[0], args[1], args[2]);
+    *rv = split_n_sections(args[0], args[1], args[2]);
   } else {
-    *rv = split(args[0], args[1], args[2]);
+    *rv = split_indices_array(args[0], args[1], args[2]);
   }
 });

From dcbee0c81befa6a371215832bc50e8db07a12d9e Mon Sep 17 00:00:00 2001
From: Hugo Latendresse
Date: Sun, 16 Mar 2025 17:31:54 -0400
Subject: [PATCH 04/18] split size test passes!

---
 .../relax/test_from_exported_to_cuda.py | 45 +++++++++++++++----
 1 file changed, 36 insertions(+), 9 deletions(-)

diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index 7f0204cd2e1f..be438cac0e05 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -15,6 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
+import sys +sys.path.append('/ssd1/htalendr/tvm/python') # Refer to local TVM build + import tvm from tvm import relax import tvm.testing @@ -50,10 +53,17 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]] gpu_out = vm["main"](gpu_data, *gpu_params) - pytorch_out = torch_module(torch_data).detach().numpy() - actual = gpu_out[0].numpy() - desired = pytorch_out - np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) + pytorch_out = torch_module(torch_data) + + if isinstance(pytorch_out, tuple): + for i in range(len(pytorch_out)): + actual = gpu_out[i].numpy() + desired = pytorch_out[i].detach().numpy() + np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) + else: + actual = gpu_out[0].numpy() + desired = pytorch_out.detach().numpy() + np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) @tvm.testing.parametrize_targets("cuda") @@ -202,11 +212,6 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) -@tvm.testing.parametrize_targets("cuda") -def test_copy_fails_(target, dev): - assert False, "test_copy_fails_ indeed fails" - - @tvm.testing.parametrize_targets("cuda") def test_upsample_with_size(target, dev): """ @@ -285,6 +290,28 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module2, target, dev) assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_split_size(target, dev): + + channels = 7 + split_size = 3 + dim = 0 # TODO try higher dims! + raw_data = np.random.rand(channels).astype("float32") + + class SplitModelSplitSize(nn.Module): + def __init__(self, split_size, dim): + super().__init__() + self.split_size = split_size + self.dim = dim + + def forward(self, x): + return torch.split(x, split_size_or_sections=self.split_size, dim=self.dim) + + torch_module = SplitModelSplitSize(split_size=split_size, dim=dim).eval() + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + if __name__ == "__main__": tvm.testing.main() From 75890ce59dfe85f99c6053734295614329335734 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 17:46:05 -0400 Subject: [PATCH 05/18] test sizes and lists --- .../relax/test_from_exported_to_cuda.py | 35 ++++++++++++++++--- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index be438cac0e05..482344f2dcb1 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -292,11 +292,14 @@ def forward(self, x): @tvm.testing.parametrize_targets("cuda") def test_split_size(target, dev): - + # Test split using the split_size argument such that it is not a divisor + # of the dimension to split (the last tensor will be smaller) + batch = 2 channels = 7 - split_size = 3 - dim = 0 # TODO try higher dims! 
- raw_data = np.random.rand(channels).astype("float32") + height, width = 2, 2 + split_size = 3 # last tensor will have just 1 element + dim = 1 # split across channels + raw_data = np.random.rand(batch, channels, height, width).astype("float32") class SplitModelSplitSize(nn.Module): def __init__(self, split_size, dim): @@ -311,6 +314,30 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_split_sections_list(target, dev): + # Test split using a list of section sizes + batch = 3 + channels = 2 + height = 10 + width = 5 + sections = [3, 2, 5] + dim = 2 # split across height + raw_data = np.random.rand(batch, channels, height, width).astype("float32") + + class SplitModelSectionsList(nn.Module): + def __init__(self, split_size, dim): + super().__init__() + self.split_size = split_size + self.dim = dim + + def forward(self, x): + return torch.split(x, split_size_or_sections=self.split_size, dim=self.dim) + + torch_module = SplitModelSectionsList(split_size=sections, dim=dim).eval() + + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + if __name__ == "__main__": From 5a8eab116c27019d8796a814e1ffbd28a7d72fca Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 17:48:28 -0400 Subject: [PATCH 06/18] just one func --- include/tvm/topi/transform.h | 10 ---------- .../torch/base_fx_graph_translator.py | 20 ------------------- .../torch/exported_program_translator.py | 2 +- 3 files changed, 1 insertion(+), 31 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 62f69d9c825e..e459c0d43a81 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -578,16 +578,6 @@ inline Tensor stack(const Array& inputs, int axis = 0, std::string name inline Array split_indices_array(const Tensor& x, Array split_indices, int axis, std::string name = "T_split", std::string tag = kInjective) { - printf("we are in the transform'h's split\n"); - // int x11 = 10; - // int y11 = 5; - // while (y11) { - // x11 -= 2; - // y11--; - // } - // int z11 = 50 / x11; - // printf("z11 is %d\n", z11); - if (axis < 0) { axis += static_cast(x->shape.size()); } diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 40342a830674..6bbc9d5de618 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1021,7 +1021,6 @@ def _scatter(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.scatter_elements(x, index, src, axis=dim)) def _split(self, node: fx.Node) -> relax.Var: - """ torch.split with split_size passed as an argument""" x = self.env[node.args[0]] split_size = node.args[1] dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0) @@ -1032,25 +1031,6 @@ def _split(self, node: fx.Node) -> relax.Var: n_section.append(s + cum_sum) else: n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size - print("calling split with n_section", n_section) - print("calling split with dim", dim) - return self.block_builder.emit(relax.op.split(x, n_section, dim)) - - def _split_with_sizes(self, node: fx.Node) -> relax.Var: - """ torch.split with a list of section sizes passed as an argument""" - # TODO - x = self.env[node.args[0]] - split_size = node.args[1] - dim = node.args[2] if len(node.args) > 2 else 
node.kwargs.get("dim", 0) - if isinstance(split_size, (list, tuple)): - n_section = [] - for s in split_size[:-1]: - cum_sum = 0 if not n_section else n_section[-1] - n_section.append(s + cum_sum) - else: - n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size - print("calling split with n_section", n_section) - print("calling split with dim", dim) return self.block_builder.emit(relax.op.split(x, n_section, dim)) def _squeeze(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 834002a02170..2abc0b024871 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -329,7 +329,7 @@ def create_convert_map( "select.int": self._select, "slice.Tensor": self._slice, "split.Tensor": self._split, - "split_with_sizes.default": self._split_with_sizes, + "split_with_sizes.default": self._split, "squeeze.default": self._squeeze, "squeeze.dim": self._squeeze, "take.default": self._take, From c771389462f4620d4ed86194109975226bd5468f Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 17:52:22 -0400 Subject: [PATCH 07/18] cleanup --- include/tvm/topi/transform.h | 1 - python/tvm/relax/op/manipulate.py | 1 - python/tvm/relax/transform/legalize_ops/manipulate.py | 8 -------- src/contrib/msc/framework/tensorrt/transform_tensorrt.cc | 2 -- src/relax/op/tensor/manipulate.cc | 2 -- src/topi/transform.cc | 3 --- tests/python/relax/test_from_exported_to_cuda.py | 2 -- 7 files changed, 19 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index e459c0d43a81..762148dcfac3 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -972,7 +972,6 @@ inline Tensor strided_slice(const Tensor& x, const Array& begin, const inline Array split_n_sections(const Tensor& x, int num_sections, int axis, std::string name = "T_split_sections", std::string tag = kInjective) { - printf("We are in transform.h's splits_sections\n"); if (axis < 0) { axis += static_cast(x->shape.size()); } diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 85f08b7f98bd..0f6e537ab3d6 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -253,7 +253,6 @@ def split( """ if isinstance(indices_or_sections, int): indices_or_sections = IntImm("int64", indices_or_sections) - print("CALLING _ffi_api.split !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") return _ffi_api.split(x, indices_or_sections, axis) # type: ignore diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index cefbba55fce3..55e3aeb4423c 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -107,19 +107,11 @@ def _permute_dims(bb: BlockBuilder, call: Call) -> Expr: @register_legalize("relax.split") def _split(bb: BlockBuilder, call: Call) -> Expr: - print("is this the split op??????????????????????????????????/") if isinstance(call.attrs.indices_or_sections, tir.IntImm): indices_or_sections = call.attrs.indices_or_sections.value modulo = tvm.arith.Analyzer().simplify( call.args[0].struct_info.shape.values[call.attrs.axis] % indices_or_sections ) - # if isinstance(modulo, tir.IntImm): - # if modulo != 0: - # logging.info( - # "Split cannot be legalized by TOPI when the axis being split 
has " - # "length that not divisible by the input number of section." - # ) - # return call else: indices_or_sections = call.attrs.indices_or_sections return bb.call_te(topi.split, call.args[0], indices_or_sections, call.attrs.axis) diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc index 92d090946cdb..94dfec7ea621 100644 --- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc +++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc @@ -782,8 +782,6 @@ Expr RewriteShapeLike(BlockBuilder builder, const Var& var, const Call& src_call Expr RewriteSplit(BlockBuilder builder, const Var& var, const Call& src_call, const Map& new_calls, const String& config) { - printf("INSIDE REWRITESPLIT !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!11\n"); - std::cout << "Why is this not building?" << std::endl; const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; const auto& input_shape = ExprUtils::GetShape(call->args[0]); const auto* src_attrs = src_call->attrs.as(); diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index f40f0ac36e0a..6489ce36b47b 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -827,7 +827,6 @@ TVM_REGISTER_NODE_TYPE(SplitAttrs); #include Expr split(Expr x, Variant> indices_or_sections, int axis) { - printf("INSIDE SPLIT!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"); ObjectPtr attrs = make_object(); if (const auto* indices = indices_or_sections.as()) { for (int i = 0; i < static_cast(indices->size()); ++i) { @@ -851,7 +850,6 @@ Expr split(Expr x, Variant> indices_or_sections, int axis) attrs->axis = axis; static const Op& op = Op::Get("relax.split"); - printf("CALLING OP !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"); return Call(op, {std::move(x)}, Attrs(attrs), {}); } diff --git a/src/topi/transform.cc b/src/topi/transform.cc index a0df64faa291..7ef63a9b3f56 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -82,10 +82,7 @@ TVM_REGISTER_GLOBAL("topi.ndarray_size").set_body([](TVMArgs args, TVMRetValue* *rv = ndarray_size(args[0], args[1]); }); -#include TVM_REGISTER_GLOBAL("topi.split").set_body([](TVMArgs args, TVMRetValue* rv) { - printf("we are in transform.cc's topi.split, called by _split python func in manipulate.py\n"); - if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) { *rv = split_n_sections(args[0], args[1], args[2]); } else { diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 482344f2dcb1..d629d3fc87d6 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. 
-import sys -sys.path.append('/ssd1/htalendr/tvm/python') # Refer to local TVM build import tvm from tvm import relax From 2fbe4c1cc03e4e31369394741d2e149899f9c9d6 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 17:56:57 -0400 Subject: [PATCH 08/18] no assert --- src/relax/op/tensor/manipulate.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 6489ce36b47b..cb738db363ee 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -824,8 +824,6 @@ TVM_REGISTER_OP("relax.reshape") /* relax.split */ TVM_REGISTER_NODE_TYPE(SplitAttrs); -#include - Expr split(Expr x, Variant> indices_or_sections, int axis) { ObjectPtr attrs = make_object(); if (const auto* indices = indices_or_sections.as()) { From e5095b8dec610e78c4d7435b05b6d4c6727402ee Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 18:01:10 -0400 Subject: [PATCH 09/18] linting --- tests/python/relax/test_from_exported_to_cuda.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index d629d3fc87d6..c120eb89811c 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -288,15 +288,16 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module2, target, dev) assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev) + @tvm.testing.parametrize_targets("cuda") def test_split_size(target, dev): - # Test split using the split_size argument such that it is not a divisor + # Test split using the split_size argument such that it is not a divisor # of the dimension to split (the last tensor will be smaller) batch = 2 channels = 7 height, width = 2, 2 - split_size = 3 # last tensor will have just 1 element - dim = 1 # split across channels + split_size = 3 # last tensor will have just 1 element + dim = 1 # split across channels raw_data = np.random.rand(batch, channels, height, width).astype("float32") class SplitModelSplitSize(nn.Module): @@ -312,6 +313,7 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + @tvm.testing.parametrize_targets("cuda") def test_split_sections_list(target, dev): # Test split using a list of section sizes @@ -320,7 +322,7 @@ def test_split_sections_list(target, dev): height = 10 width = 5 sections = [3, 2, 5] - dim = 2 # split across height + dim = 2 # split across height raw_data = np.random.rand(batch, channels, height, width).astype("float32") class SplitModelSectionsList(nn.Module): @@ -337,6 +339,5 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - if __name__ == "__main__": tvm.testing.main() From 490a454c20f429fe54425182fde49bd5ccedf4a6 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 18:19:32 -0400 Subject: [PATCH 10/18] chunk --- .../torch/base_fx_graph_translator.py | 8 ++++++ .../torch/exported_program_translator.py | 1 + .../relax/test_from_exported_to_cuda.py | 25 +++++++++++++++++++ 3 files changed, 34 insertions(+) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 6bbc9d5de618..90598640aa21 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ 
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -935,6 +935,14 @@ def _cat(self, node: fx.Node) -> relax.Var: axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) return self.block_builder.emit(relax.op.concat(args[0], axis=axis)) + def _chunk(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + chunks = node.args[1] + dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0) + length_dim = int(self.shape_of(x)[dim]) + n_section = math.ceil(length_dim / chunks) + return self.block_builder.emit(relax.op.split(x, n_section, dim)) + def _cumsum(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 2abc0b024871..23d4063748f2 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -316,6 +316,7 @@ def create_convert_map( "argmin.default": self._argmax_argmin(relax.op.argmin), # tensor manipulation "cat.default": self._cat, + "chunk.default": self._chunk, "clamp.Tensor": self._clamp, "concat.default": self._cat, "copy_.default": self._copy_, diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index c120eb89811c..d834b54b51a1 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -15,6 +15,10 @@ # specific language governing permissions and limitations # under the License. +import sys + +sys.path.append("/ssd1/htalendr/tvm/python") # Refer to local TVM build + import tvm from tvm import relax @@ -336,6 +340,27 @@ def forward(self, x): torch_module = SplitModelSectionsList(split_size=sections, dim=dim).eval() + +@tvm.testing.parametrize_targets("cuda") +def test_chunk(target, dev): + batch = 3 + channels = 5 + height = 7 + width = 11 + chunks = 2 + dim = 1 + raw_data = np.random.rand(batch, channels, height, width).astype("float32") + + class ChunkModel(nn.Module): + def __init__(self, chunks, dim): + super().__init__() + self.chunks = chunks + self.dim = dim + + def forward(self, x): + return x.chunk(self.chunks, dim=self.dim) + + torch_module = ChunkModel(chunks=chunks, dim=dim).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) From ec7311cdb03f86686f19412a20acb8d4eed41fcc Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Thu, 20 Mar 2025 08:51:06 -0400 Subject: [PATCH 11/18] remove unsused modulo --- python/tvm/relax/transform/legalize_ops/manipulate.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 55e3aeb4423c..662d4e946b5f 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -16,7 +16,6 @@ # under the License. 
# pylint: disable=invalid-name """Default legalization function for manipulate operators.""" -import logging from typing import Optional import tvm @@ -109,9 +108,6 @@ def _permute_dims(bb: BlockBuilder, call: Call) -> Expr: def _split(bb: BlockBuilder, call: Call) -> Expr: if isinstance(call.attrs.indices_or_sections, tir.IntImm): indices_or_sections = call.attrs.indices_or_sections.value - modulo = tvm.arith.Analyzer().simplify( - call.args[0].struct_info.shape.values[call.attrs.axis] % indices_or_sections - ) else: indices_or_sections = call.attrs.indices_or_sections return bb.call_te(topi.split, call.args[0], indices_or_sections, call.attrs.axis) From 0744701366475bc876dc68aff51d1a612095b48b Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Thu, 20 Mar 2025 11:13:34 -0400 Subject: [PATCH 12/18] fixed first test --- .../test_transform_legalize_ops_manipulate.py | 139 +++++++++++------- 1 file changed, 86 insertions(+), 53 deletions(-) diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index 0565b7a5790a..db1b5c08ceca 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. +import sys +sys.path.append('/ssd1/htalendr/tvm/python') + import tvm from tvm import relax from tvm.relax.transform import LegalizeOps @@ -788,102 +791,132 @@ def test_split_by_indices_n_section_indivisible(): class Split: @R.function def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]): - gv: R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]) = R.split(x, 3, axis=1) - return gv - # fmt: on - - mod = LegalizeOps()(Split) - tvm.ir.assert_structural_equal(mod, Split) - - -def test_split_by_indices_n_section_divisible(): - # fmt: off - @tvm.script.ir_module - class Split: - @R.function - def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]): - gv: R.Tuple([R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]) = R.split(x, 2, axis=1) + gv: R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]) = R.split(x, indices_or_sections=3, axis=1) return gv @tvm.script.ir_module class Expected: @R.function - def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]): - gv = R.call_tir(Expected.split, (x,), [R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]) + def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]): + gv = R.call_tir(Expected.split, (x,), [R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]) return gv @T.prim_func(private=True) - def split(rxplaceholder: T.Buffer((T.int64(2), T.int64(10), T.int64(4)), "float32"), T_split_sections: T.Buffer((T.int64(2), T.int64(5), T.int64(4)), "float32"), T_split_sections_1: T.Buffer((T.int64(2), T.int64(5), T.int64(4)), "float32")): + def split(rxplaceholder: T.Buffer((T.int64(2), T.int64(10), T.int64(4)), "float32"), T_split_sections: T.Buffer((T.int64(2), T.int64(4), T.int64(4)), "float32"), T_split_sections_1: 
T.Buffer((T.int64(2), T.int64(4), T.int64(4)), "float32"), T_split_sections_2: T.Buffer((T.int64(2), T.int64(2), T.int64(4)), "float32")): T.func_attr({"tir.noalias": True}) - for i0, i1, i2 in T.grid(T.int64(2), T.int64(5), T.int64(4)): + for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(4)): with T.block("T_split_sections"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[ax0, ax1, ax2]) T.writes(T_split_sections[ax0, ax1, ax2]) T_split_sections[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] - for i0, i1, i2 in T.grid(T.int64(2), T.int64(5), T.int64(4)): + for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(4)): with T.block("T_split_sections_1"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(rxplaceholder[ax0, ax1 + T.int64(5), ax2]) + T.reads(rxplaceholder[ax0, ax1 + T.int64(4), ax2]) T.writes(T_split_sections_1[ax0, ax1, ax2]) - T_split_sections_1[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 + T.int64(5), ax2] + T_split_sections_1[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 + T.int64(4), ax2] + for i0, i1, i2 in T.grid(T.int64(2), T.int64(2), T.int64(4)): + with T.block("T_split_sections_2"): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1 + T.int64(8), ax2]) + T.writes(T_split_sections_2[ax0, ax1, ax2]) + T_split_sections_2[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 + T.int64(8), ax2] + # fmt: on mod = LegalizeOps()(Split) tvm.ir.assert_structural_equal(mod, Expected) -def test_split_by_indices_n_section_divisible_symbolic(): +def test_split_by_indices_n_section_divisible(): # fmt: off @tvm.script.ir_module class Split: @R.function - def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "n * 3"), "float32")) -> R.Tuple([R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32")]): - m = T.int64() - n = T.int64() - gv: R.Tuple([R.Tensor((m, n), "float32"), R.Tensor((m, n), "float32"), R.Tensor((m, n), "float32")]) = R.split(x, 3, axis=1) + def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]): + gv: R.Tuple([R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]) = R.split(x, 2, axis=1) return gv @tvm.script.ir_module class Expected: @R.function - def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "(n * 3)"), "float32")) -> R.Tuple(R.Tensor(("m", "((n * 3) // 3)"), "float32"), R.Tensor(("m", "((((n * 3) // 3) * 2) - ((n * 3) // 3))"), "float32"), R.Tensor(("m", "((n * 3) - (((n * 3) // 3) * 2))"), "float32")): - m = T.int64() - n = T.int64() - gv = R.call_tir(Expected.split, (x,), [R.Tensor((m, ((n * 3) // 3)), "float32"), R.Tensor((m, ((((n * 3) // 3) * 2) - ((n * 3) // 3))), "float32"), R.Tensor((m, ((n * 3) - (((n * 3) // 3) * 2))), "float32")], tir_vars=(n,)) + def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]): + gv = R.call_tir(Expected.split, (x,), [R.Tensor((2, 5, 4), "float32"), R.Tensor((2, 5, 4), "float32")]) return gv @T.prim_func(private=True) - def split(var_rxplaceholder: T.handle, var_T_split_sections: T.handle, var_T_split_sections_1: T.handle, var_T_split_sections_2: T.handle, n: T.int64): + def split(rxplaceholder: T.Buffer((T.int64(2), T.int64(10), T.int64(4)), "float32"), T_split_sections: T.Buffer((T.int64(2), T.int64(5), T.int64(4)), "float32"), T_split_sections_1: T.Buffer((T.int64(2), T.int64(5), T.int64(4)), "float32")): T.func_attr({"tir.noalias": True}) - m = T.int64() - rxplaceholder = 
T.match_buffer(var_rxplaceholder, [m, n * T.int64(3)], dtype="float32") - T_split_sections = T.match_buffer(var_T_split_sections, [m, n * T.int64(3) // T.int64(3)], dtype="float32") - T_split_sections_1 = T.match_buffer(var_T_split_sections_1, [m, n * T.int64(3) // T.int64(3) * T.int64(2) - n * T.int64(3) // T.int64(3)], dtype="float32") - T_split_sections_2 = T.match_buffer(var_T_split_sections_2, [m, n * T.int64(3) - n * T.int64(3) // T.int64(3) * T.int64(2)], dtype="float32") - for i0, i1 in T.grid(m, n): + for i0, i1, i2 in T.grid(T.int64(2), T.int64(5), T.int64(4)): with T.block("T_split_sections"): - ax0, ax1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[ax0, ax1]) - T.writes(T_split_sections[ax0, ax1]) - T_split_sections[ax0, ax1] = rxplaceholder[ax0, ax1] - for i0, i1 in T.grid(m, n): + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1, ax2]) + T.writes(T_split_sections[ax0, ax1, ax2]) + T_split_sections[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] + for i0, i1, i2 in T.grid(T.int64(2), T.int64(5), T.int64(4)): with T.block("T_split_sections_1"): - ax0, ax1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[ax0, n + ax1]) - T.writes(T_split_sections_1[ax0, ax1]) - T_split_sections_1[ax0, ax1] = rxplaceholder[ax0, n + ax1] - for i0, i1 in T.grid(m, n): - with T.block("T_split_sections_2"): - ax0, ax1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[ax0, n * T.int64(2) + ax1]) - T.writes(T_split_sections_2[ax0, ax1]) - T_split_sections_2[ax0, ax1] = rxplaceholder[ax0, n * T.int64(2) + ax1] + ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[ax0, ax1 + T.int64(5), ax2]) + T.writes(T_split_sections_1[ax0, ax1, ax2]) + T_split_sections_1[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 + T.int64(5), ax2] # fmt: on mod = LegalizeOps()(Split) tvm.ir.assert_structural_equal(mod, Expected) +# TODO uncomment +# def test_split_by_indices_n_section_divisible_symbolic(): +# # fmt: off +# @tvm.script.ir_module +# class Split: +# @R.function +# def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "n * 3"), "float32")) -> R.Tuple([R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32")]): +# m = T.int64() +# n = T.int64() +# gv: R.Tuple([R.Tensor((m, n), "float32"), R.Tensor((m, n), "float32"), R.Tensor((m, n), "float32")]) = R.split(x, 3, axis=1) +# return gv + +# @tvm.script.ir_module +# class Expected: +# @R.function +# def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "(n * 3)"), "float32")) -> R.Tuple(R.Tensor(("m", "((n * 3) // 3)"), "float32"), R.Tensor(("m", "((((n * 3) // 3) * 2) - ((n * 3) // 3))"), "float32"), R.Tensor(("m", "((n * 3) - (((n * 3) // 3) * 2))"), "float32")): +# m = T.int64() +# n = T.int64() +# gv = R.call_tir(Expected.split, (x,), [R.Tensor((m, ((n * 3) // 3)), "float32"), R.Tensor((m, ((((n * 3) // 3) * 2) - ((n * 3) // 3))), "float32"), R.Tensor((m, ((n * 3) - (((n * 3) // 3) * 2))), "float32")], tir_vars=(n,)) +# return gv + +# @T.prim_func(private=True) +# def split(var_rxplaceholder: T.handle, var_T_split_sections: T.handle, var_T_split_sections_1: T.handle, var_T_split_sections_2: T.handle, n: T.int64): +# T.func_attr({"tir.noalias": True}) +# m = T.int64() +# rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n * T.int64(3)], dtype="float32") +# T_split_sections = T.match_buffer(var_T_split_sections, [m, n * T.int64(3) // T.int64(3)], dtype="float32") +# T_split_sections_1 = T.match_buffer(var_T_split_sections_1, [m, n * T.int64(3) // 
T.int64(3) * T.int64(2) - n * T.int64(3) // T.int64(3)], dtype="float32") +# T_split_sections_2 = T.match_buffer(var_T_split_sections_2, [m, n * T.int64(3) - n * T.int64(3) // T.int64(3) * T.int64(2)], dtype="float32") +# for i0, i1 in T.grid(m, n): +# with T.block("T_split_sections"): +# ax0, ax1 = T.axis.remap("SS", [i0, i1]) +# T.reads(rxplaceholder[ax0, ax1]) +# T.writes(T_split_sections[ax0, ax1]) +# T_split_sections[ax0, ax1] = rxplaceholder[ax0, ax1] +# for i0, i1 in T.grid(m, n): +# with T.block("T_split_sections_1"): +# ax0, ax1 = T.axis.remap("SS", [i0, i1]) +# T.reads(rxplaceholder[ax0, n + ax1]) +# T.writes(T_split_sections_1[ax0, ax1]) +# T_split_sections_1[ax0, ax1] = rxplaceholder[ax0, n + ax1] +# for i0, i1 in T.grid(m, n): +# with T.block("T_split_sections_2"): +# ax0, ax1 = T.axis.remap("SS", [i0, i1]) +# T.reads(rxplaceholder[ax0, n * T.int64(2) + ax1]) +# T.writes(T_split_sections_2[ax0, ax1]) +# T_split_sections_2[ax0, ax1] = rxplaceholder[ax0, n * T.int64(2) + ax1] +# # fmt: on + +# mod = LegalizeOps()(Split) +# tvm.ir.assert_structural_equal(mod, Expected) + def test_squeeze(): # fmt: off From 40f171183e9a8194781ebe2d7592ea377756b0bd Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Thu, 20 Mar 2025 11:34:23 -0400 Subject: [PATCH 13/18] fixed second test and lint --- .../test_transform_legalize_ops_manipulate.py | 105 +++++++++--------- 1 file changed, 53 insertions(+), 52 deletions(-) diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index db1b5c08ceca..4836ffd01041 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -16,7 +16,8 @@ # under the License. 
import sys -sys.path.append('/ssd1/htalendr/tvm/python') + +sys.path.append("/ssd1/htalendr/tvm/python") import tvm from tvm import relax @@ -865,57 +866,57 @@ def split(rxplaceholder: T.Buffer((T.int64(2), T.int64(10), T.int64(4)), "float3 mod = LegalizeOps()(Split) tvm.ir.assert_structural_equal(mod, Expected) -# TODO uncomment -# def test_split_by_indices_n_section_divisible_symbolic(): -# # fmt: off -# @tvm.script.ir_module -# class Split: -# @R.function -# def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "n * 3"), "float32")) -> R.Tuple([R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32")]): -# m = T.int64() -# n = T.int64() -# gv: R.Tuple([R.Tensor((m, n), "float32"), R.Tensor((m, n), "float32"), R.Tensor((m, n), "float32")]) = R.split(x, 3, axis=1) -# return gv - -# @tvm.script.ir_module -# class Expected: -# @R.function -# def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "(n * 3)"), "float32")) -> R.Tuple(R.Tensor(("m", "((n * 3) // 3)"), "float32"), R.Tensor(("m", "((((n * 3) // 3) * 2) - ((n * 3) // 3))"), "float32"), R.Tensor(("m", "((n * 3) - (((n * 3) // 3) * 2))"), "float32")): -# m = T.int64() -# n = T.int64() -# gv = R.call_tir(Expected.split, (x,), [R.Tensor((m, ((n * 3) // 3)), "float32"), R.Tensor((m, ((((n * 3) // 3) * 2) - ((n * 3) // 3))), "float32"), R.Tensor((m, ((n * 3) - (((n * 3) // 3) * 2))), "float32")], tir_vars=(n,)) -# return gv - -# @T.prim_func(private=True) -# def split(var_rxplaceholder: T.handle, var_T_split_sections: T.handle, var_T_split_sections_1: T.handle, var_T_split_sections_2: T.handle, n: T.int64): -# T.func_attr({"tir.noalias": True}) -# m = T.int64() -# rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n * T.int64(3)], dtype="float32") -# T_split_sections = T.match_buffer(var_T_split_sections, [m, n * T.int64(3) // T.int64(3)], dtype="float32") -# T_split_sections_1 = T.match_buffer(var_T_split_sections_1, [m, n * T.int64(3) // T.int64(3) * T.int64(2) - n * T.int64(3) // T.int64(3)], dtype="float32") -# T_split_sections_2 = T.match_buffer(var_T_split_sections_2, [m, n * T.int64(3) - n * T.int64(3) // T.int64(3) * T.int64(2)], dtype="float32") -# for i0, i1 in T.grid(m, n): -# with T.block("T_split_sections"): -# ax0, ax1 = T.axis.remap("SS", [i0, i1]) -# T.reads(rxplaceholder[ax0, ax1]) -# T.writes(T_split_sections[ax0, ax1]) -# T_split_sections[ax0, ax1] = rxplaceholder[ax0, ax1] -# for i0, i1 in T.grid(m, n): -# with T.block("T_split_sections_1"): -# ax0, ax1 = T.axis.remap("SS", [i0, i1]) -# T.reads(rxplaceholder[ax0, n + ax1]) -# T.writes(T_split_sections_1[ax0, ax1]) -# T_split_sections_1[ax0, ax1] = rxplaceholder[ax0, n + ax1] -# for i0, i1 in T.grid(m, n): -# with T.block("T_split_sections_2"): -# ax0, ax1 = T.axis.remap("SS", [i0, i1]) -# T.reads(rxplaceholder[ax0, n * T.int64(2) + ax1]) -# T.writes(T_split_sections_2[ax0, ax1]) -# T_split_sections_2[ax0, ax1] = rxplaceholder[ax0, n * T.int64(2) + ax1] -# # fmt: on - -# mod = LegalizeOps()(Split) -# tvm.ir.assert_structural_equal(mod, Expected) + +def test_split_by_indices_n_section_divisible_symbolic(): + # fmt: off + @tvm.script.ir_module + class Split: + @R.function + def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "n * 3"), "float32")) -> R.Tuple([R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32"), R.Tensor(("m", "n"), "float32")]): + m = T.int64() + n = T.int64() + gv: R.Tuple([R.Tensor((m, n), "float32"), R.Tensor((m, n), "float32"), R.Tensor((m, n), "float32")]) = R.split(x, 3, axis=1) + 
return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "(n * 3)"), "float32")) -> R.Tuple(R.Tensor(("m", "((n * 3) // 3)"), "float32"), R.Tensor(("m", "((((n * 3) // 3) * 2) - ((n * 3) // 3))"), "float32"), R.Tensor(("m", "((n * 3) - (((n * 3) // 3) * 2))"), "float32")): + m = T.int64() + n = T.int64() + gv = R.call_tir(Expected.split, (x,), [R.Tensor((m, ((n * 3 + 3 - 1) // 3)), "float32"), R.Tensor((m, ((((n * 3 + 3 - 1) // 3) * 2) - ((n * 3 + 3 - 1) // 3))), "float32"), R.Tensor((m, ((n * 3) - (((n * 3 + 3 - 1) // 3) * 2))), "float32")], tir_vars=R.shape([n])) + return gv + + @T.prim_func(private=True) + def split(var_rxplaceholder: T.handle, var_T_split_sections: T.handle, var_T_split_sections_1: T.handle, var_T_split_sections_2: T.handle, n: T.int64): + T.func_attr({"tir.noalias": True}) + m = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n * T.int64(3)], dtype="float32") + T_split_sections = T.match_buffer(var_T_split_sections, [m, (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3)], dtype="float32") + T_split_sections_1 = T.match_buffer(var_T_split_sections_1, [m, (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3) * T.int64(2) - (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3)], dtype="float32") + T_split_sections_2 = T.match_buffer(var_T_split_sections_2, [m, n * T.int64(3) - (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3) * T.int64(2)], dtype="float32") + for i0, i1 in T.grid(m, n): + with T.block("T_split_sections"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1]) + T.writes(T_split_sections[ax0, ax1]) + T_split_sections[ax0, ax1] = rxplaceholder[ax0, ax1] + for i0, i1 in T.grid(m, n): + with T.block("T_split_sections_1"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, ax1 + n]) + T.writes(T_split_sections_1[ax0, ax1]) + T_split_sections_1[ax0, ax1] = rxplaceholder[ax0, ax1 + n] + for i0, i1 in T.grid(m, n): + with T.block("T_split_sections_2"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[ax0, n * T.int64(2) + ax1]) + T.writes(T_split_sections_2[ax0, ax1]) + T_split_sections_2[ax0, ax1] = rxplaceholder[ax0, n * T.int64(2) + ax1] + # fmt: on + + mod = LegalizeOps()(Split) + tvm.ir.assert_structural_equal(mod, Expected) def test_squeeze(): From 4582c3a7e94e6dbe5c8a86e351a3e5e8ec8f505c Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 24 Mar 2025 12:25:20 -0400 Subject: [PATCH 14/18] linting --- tests/python/relax/test_from_exported_to_cuda.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index c28a2ff63f1f..12b50b3886f3 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -335,6 +335,7 @@ def forward(self, x): torch_module = SplitModelSectionsList(split_size=sections, dim=dim).eval() + @tvm.testing.parametrize_targets("cuda") def test_chunk(target, dev): batch = 3 From afa793a9fe791ae729d71ac895443820e588fa5e Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 24 Mar 2025 12:25:59 -0400 Subject: [PATCH 15/18] fix one test --- tests/python/relax/test_from_exported_to_cuda.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 12b50b3886f3..5c7f30d14819 100644 --- 
a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -334,6 +334,7 @@ def forward(self, x): return torch.split(x, split_size_or_sections=self.split_size, dim=self.dim) torch_module = SplitModelSectionsList(split_size=sections, dim=dim).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) @tvm.testing.parametrize_targets("cuda") From d71b51855b9053df98bc0f99b09af85d78b8085d Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 24 Mar 2025 14:20:06 -0400 Subject: [PATCH 16/18] chunk not passing anymore --- tests/python/relax/test_from_exported_to_cuda.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 5c7f30d14819..8b788bf1628e 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -341,9 +341,9 @@ def forward(self, x): def test_chunk(target, dev): batch = 3 channels = 5 - height = 7 - width = 11 - chunks = 2 + height = 2 + width = 7 + chunks = 11 dim = 1 raw_data = np.random.rand(batch, channels, height, width).astype("float32") From e95aef67053c1093f2267c064beae4e1ca2beff9 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 24 Mar 2025 14:23:34 -0400 Subject: [PATCH 17/18] get_item error --- tests/python/relax/test_from_exported_to_cuda.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 8b788bf1628e..11e55919e139 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+import sys +sys.path.append('/ssd1/htalendr/tvm/python') import tvm from tvm import relax From bc504461c8c072b0455d2871b1b73913bfe9d2ed Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 24 Mar 2025 14:55:46 -0400 Subject: [PATCH 18/18] chunk unit tests --- .../torch/base_fx_graph_translator.py | 9 ++- .../relax/test_from_exported_to_cuda.py | 59 +++++++++++++++++-- 2 files changed, 59 insertions(+), 9 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 176b42ce997b..2d73c8dced5e 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -960,9 +960,12 @@ def _chunk(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] chunks = node.args[1] dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0) - length_dim = int(self.shape_of(x)[dim]) - n_section = math.ceil(length_dim / chunks) - return self.block_builder.emit(relax.op.split(x, n_section, dim)) + x_shape = self.shape_of(x) + max_chunks = x_shape[dim].value + n_sections = min(chunks, max_chunks) + return self.block_builder.emit( + relax.op.split(x=x, indices_or_sections=n_sections, axis=dim) + ) def _cumsum(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 11e55919e139..77d0db0b4a99 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import sys -sys.path.append('/ssd1/htalendr/tvm/python') import tvm from tvm import relax @@ -340,12 +338,61 @@ def forward(self, x): @tvm.testing.parametrize_targets("cuda") -def test_chunk(target, dev): - batch = 3 +def test_chunk_even(target, dev): + # Chunks is a divisor of the dimension size + batch = 6 + channels = 2 + height = 3 + width = 4 + chunks = 3 + dim = 0 + raw_data = np.random.rand(batch, channels, height, width).astype("float32") + + class ChunkModel(nn.Module): + def __init__(self, chunks, dim): + super().__init__() + self.chunks = chunks + self.dim = dim + + def forward(self, x): + return x.chunk(self.chunks, dim=self.dim) + + torch_module = ChunkModel(chunks=chunks, dim=dim).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_chunk_uneven(target, dev): + # Chunks is not a divisor of the dimension size + batch = 2 channels = 5 + height = 4 + width = 5 + chunks = 2 + dim = 1 + raw_data = np.random.rand(batch, channels, height, width).astype("float32") + + class ChunkModel(nn.Module): + def __init__(self, chunks, dim): + super().__init__() + self.chunks = chunks + self.dim = dim + + def forward(self, x): + return x.chunk(self.chunks, dim=self.dim) + + torch_module = ChunkModel(chunks=chunks, dim=dim).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_chunk_too_many(target, dev): + # If user asks for more chunks than the size of the dim, pytorch simply splits in sections of size 1 + batch = 1 + channels = 3 height = 2 - width = 7 - chunks = 11 + width = 2 + chunks = 99 dim = 1 raw_data = np.random.rand(batch, channels, height, 
width).astype("float32")