diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index faacd2ce5760..762148dcfac3 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -575,8 +575,9 @@ inline Tensor stack(const Array<Tensor>& inputs, int axis = 0, std::string name
  *
  * \return A Tensor whose op member is the split operation
  */
-inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int axis,
-                           std::string name = "T_split", std::string tag = kInjective) {
+inline Array<Tensor> split_indices_array(const Tensor& x, Array<PrimExpr> split_indices, int axis,
+                                         std::string name = "T_split",
+                                         std::string tag = kInjective) {
   if (axis < 0) {
     axis += static_cast<int>(x->shape.size());
   }
@@ -968,9 +969,9 @@ inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const
  *
  * \return A Tensor whose op member is the split operation
  */
-inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
-                                    std::string name = "T_split_sections",
-                                    std::string tag = kInjective) {
+inline Array<Tensor> split_n_sections(const Tensor& x, int num_sections, int axis,
+                                      std::string name = "T_split_sections",
+                                      std::string tag = kInjective) {
   if (axis < 0) {
     axis += static_cast<int>(x->shape.size());
   }
@@ -980,14 +981,8 @@ inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
 
   ICHECK_GT(num_sections, 0) << "Slice count must be > 0";
 
-  if (auto node = src_axis_size.as<IntImmNode>()) {
-    ICHECK_EQ(node->value % num_sections, 0)
-        << "num_sections must be an integer factor of the size of axis " << axis << " ("
-        << node->value << ")";
-  }
-
   Array<PrimExpr> split_indices;
-  auto seg_size = indexdiv(src_axis_size, num_sections);
+  auto seg_size = indexdiv(src_axis_size + num_sections - 1, num_sections);
   for (int i = 0; i < num_sections; ++i) {
     // region at index 0 is added by split()
     if (i != 0) {
@@ -995,7 +990,7 @@ inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
     }
   }
 
-  return split(x, split_indices, axis, name, tag);
+  return split_indices_array(x, split_indices, axis, name, tag);
 }
 
 /*!
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index bc7a4c4cb046..2abc0b024871 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -329,6 +329,7 @@ def create_convert_map(
             "select.int": self._select,
             "slice.Tensor": self._slice,
             "split.Tensor": self._split,
+            "split_with_sizes.default": self._split,
             "squeeze.default": self._squeeze,
             "squeeze.dim": self._squeeze,
             "take.default": self._take,
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index c71a41dc1c2d..662d4e946b5f 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -16,7 +16,6 @@
 # under the License.
 # pylint: disable=invalid-name
 """Default legalization function for manipulate operators."""
-import logging
 from typing import Optional
 
 import tvm
@@ -109,16 +108,6 @@ def _permute_dims(bb: BlockBuilder, call: Call) -> Expr:
 def _split(bb: BlockBuilder, call: Call) -> Expr:
     if isinstance(call.attrs.indices_or_sections, tir.IntImm):
         indices_or_sections = call.attrs.indices_or_sections.value
-        modulo = tvm.arith.Analyzer().simplify(
-            call.args[0].struct_info.shape.values[call.attrs.axis] % indices_or_sections
-        )
-        if isinstance(modulo, tir.IntImm):
-            if modulo != 0:
-                logging.info(
-                    "Split cannot be legalized by TOPI when the axis being split has "
-                    "length that not divisible by the input number of section."
-                )
-                return call
     else:
         indices_or_sections = call.attrs.indices_or_sections
     return bb.call_te(topi.split, call.args[0], indices_or_sections, call.attrs.axis)
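
Note: with "split_with_sizes.default" now routed to _split and the divisibility
bail-out removed from the legalizer, a torch.split with uneven sections should
import and legalize end to end. A minimal sketch of that path (illustrative
only, not part of this patch; assumes a TVM build that includes it):

import torch
from torch import nn
from torch.export import export

from tvm.relax.frontend.torch import from_exported_program

class M(nn.Module):
    def forward(self, x):
        # Lowers to aten split_with_sizes.default in the exported program.
        return torch.split(x, [3, 2, 5], dim=2)

# Before this patch, _split logged a message and returned the call
# un-legalized whenever the axis length was not divisible by the
# requested number of sections.
ep = export(M().eval(), (torch.randn(3, 2, 10, 5),))
mod = from_exported_program(ep)
mod.show()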
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index 2e0fde3b289f..7ef63a9b3f56 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -84,9 +84,9 @@ TVM_REGISTER_GLOBAL("topi.ndarray_size").set_body([](TVMArgs args, TVMRetValue*
 
 TVM_REGISTER_GLOBAL("topi.split").set_body([](TVMArgs args, TVMRetValue* rv) {
   if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) {
-    *rv = split_sections(args[0], args[1], args[2]);
+    *rv = split_n_sections(args[0], args[1], args[2]);
   } else {
-    *rv = split(args[0], args[1], args[2]);
+    *rv = split_indices_array(args[0], args[1], args[2]);
   }
 });
 
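
Note: the "topi.split" global dispatches on the type code of its second
argument, so both C++ entry points above remain reachable through the single
Python topi.split wrapper. A small sketch of the two paths under the new
ceil-division sizing (illustrative only; assumes this patch):

import tvm
from tvm import te, topi

x = te.placeholder((2, 7, 3), name="x")

# Integer argument -> split_n_sections: seg_size = ceil(7 / 3) = 3, giving
# shapes (2, 3, 3), (2, 3, 3), (2, 1, 3); the removed ICHECK previously
# rejected this because 7 % 3 != 0.
by_sections = topi.split(x, 3, axis=1)

# Index-array argument -> split_indices_array: cuts at 2 and 5, giving
# shapes (2, 2, 3), (2, 3, 3), (2, 2, 3).
by_indices = topi.split(x, [2, 5], axis=1)

print([t.shape for t in by_sections], [t.shape for t in by_indices])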
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index 6cc12370d648..c120eb89811c 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
+
 import tvm
 from tvm import relax
 import tvm.testing
@@ -50,10 +51,17 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
     gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]]
     gpu_out = vm["main"](gpu_data, *gpu_params)
 
-    pytorch_out = torch_module(torch_data).detach().numpy()
-    actual = gpu_out[0].numpy()
-    desired = pytorch_out
-    np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
+    pytorch_out = torch_module(torch_data)
+
+    if isinstance(pytorch_out, tuple):
+        for i in range(len(pytorch_out)):
+            actual = gpu_out[i].numpy()
+            desired = pytorch_out[i].detach().numpy()
+            np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
+    else:
+        actual = gpu_out[0].numpy()
+        desired = pytorch_out.detach().numpy()
+        np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
 
 
 @tvm.testing.parametrize_targets("cuda")
@@ -281,5 +289,55 @@ def forward(self, x):
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_split_size(target, dev):
+    # Test split using a split_size that is not a divisor of the dimension
+    # being split, so the last chunk comes out smaller.
+    batch = 2
+    channels = 7
+    height, width = 2, 2
+    split_size = 3  # last chunk will have just 1 channel
+    dim = 1  # split across channels
+    raw_data = np.random.rand(batch, channels, height, width).astype("float32")
+
+    class SplitModelSplitSize(nn.Module):
+        def __init__(self, split_size, dim):
+            super().__init__()
+            self.split_size = split_size
+            self.dim = dim
+
+        def forward(self, x):
+            return torch.split(x, split_size_or_sections=self.split_size, dim=self.dim)
+
+    torch_module = SplitModelSplitSize(split_size=split_size, dim=dim).eval()
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_split_sections_list(target, dev):
+    # Test split using a list of section sizes.
+    batch = 3
+    channels = 2
+    height = 10
+    width = 5
+    sections = [3, 2, 5]
+    dim = 2  # split across height
+    raw_data = np.random.rand(batch, channels, height, width).astype("float32")
+
+    class SplitModelSectionsList(nn.Module):
+        def __init__(self, split_size, dim):
+            super().__init__()
+            self.split_size = split_size
+            self.dim = dim
+
+        def forward(self, x):
+            return torch.split(x, split_size_or_sections=self.split_size, dim=self.dim)
+
+    torch_module = SplitModelSectionsList(split_size=sections, dim=dim).eval()
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
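
Note: the sizing these two tests expect matches torch.split's chunk-size
semantics, which the ceil-division change reproduces: with seg = ceil(L / k),
every chunk has size seg except the last, which absorbs the remainder. A quick
illustration of the shapes test_split_size checks (illustrative only):

import torch

x = torch.randn(2, 7, 2, 2)
# split_size=3 over 7 channels -> channel counts [3, 3, 1]
print([t.shape[1] for t in torch.split(x, 3, dim=1)])

The same arithmetic explains the Expected shapes in the legalization test
below: R.split(x, indices_or_sections=3, axis=1) on an axis of length 10 uses
seg = ceil(10 / 3) = 4, producing (2, 4, 4), (2, 4, 4), (2, 2, 4).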
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index 0565b7a5790a..4836ffd01041 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -788,12 +788,42 @@ def test_split_by_indices_n_section_indivisible():
     class Split:
         @R.function
         def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]):
-            gv: R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]) = R.split(x, 3, axis=1)
+            gv: R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]) = R.split(x, indices_or_sections=3, axis=1)
             return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]):
+            gv = R.call_tir(Expected.split, (x,), [R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")])
+            return gv
+
+        @T.prim_func(private=True)
+        def split(rxplaceholder: T.Buffer((T.int64(2), T.int64(10), T.int64(4)), "float32"), T_split_sections: T.Buffer((T.int64(2), T.int64(4), T.int64(4)), "float32"), T_split_sections_1: T.Buffer((T.int64(2), T.int64(4), T.int64(4)), "float32"), T_split_sections_2: T.Buffer((T.int64(2), T.int64(2), T.int64(4)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(4)):
+                with T.block("T_split_sections"):
+                    ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
+                    T.reads(rxplaceholder[ax0, ax1, ax2])
+                    T.writes(T_split_sections[ax0, ax1, ax2])
+                    T_split_sections[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2]
+            for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(4)):
+                with T.block("T_split_sections_1"):
+                    ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
+                    T.reads(rxplaceholder[ax0, ax1 + T.int64(4), ax2])
+                    T.writes(T_split_sections_1[ax0, ax1, ax2])
+                    T_split_sections_1[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 + T.int64(4), ax2]
+            for i0, i1, i2 in T.grid(T.int64(2), T.int64(2), T.int64(4)):
+                with T.block("T_split_sections_2"):
+                    ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
+                    T.reads(rxplaceholder[ax0, ax1 + T.int64(8), ax2])
+                    T.writes(T_split_sections_2[ax0, ax1, ax2])
+                    T_split_sections_2[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 + T.int64(8), ax2]
+
     # fmt: on
 
     mod = LegalizeOps()(Split)
-    tvm.ir.assert_structural_equal(mod, Split)
+    tvm.ir.assert_structural_equal(mod, Expected)
 
 
 def test_split_by_indices_n_section_divisible():
@@ -850,7 +880,7 @@ class Expected:
         def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "(n * 3)"), "float32")) -> R.Tuple(R.Tensor(("m", "((n * 3) // 3)"), "float32"), R.Tensor(("m", "((((n * 3) // 3) * 2) - ((n * 3) // 3))"), "float32"), R.Tensor(("m", "((n * 3) - (((n * 3) // 3) * 2))"), "float32")):
             m = T.int64()
             n = T.int64()
-            gv = R.call_tir(Expected.split, (x,), [R.Tensor((m, ((n * 3) // 3)), "float32"), R.Tensor((m, ((((n * 3) // 3) * 2) - ((n * 3) // 3))), "float32"), R.Tensor((m, ((n * 3) - (((n * 3) // 3) * 2))), "float32")], tir_vars=(n,))
+            gv = R.call_tir(Expected.split, (x,), [R.Tensor((m, ((n * 3 + 3 - 1) // 3)), "float32"), R.Tensor((m, ((((n * 3 + 3 - 1) // 3) * 2) - ((n * 3 + 3 - 1) // 3))), "float32"), R.Tensor((m, ((n * 3) - (((n * 3 + 3 - 1) // 3) * 2))), "float32")], tir_vars=R.shape([n]))
             return gv
 
         @T.prim_func(private=True)
@@ -858,9 +888,9 @@ def split(var_rxplaceholder: T.handle, var_T_split_sections: T.handle, var_T_spl
             T.func_attr({"tir.noalias": True})
             m = T.int64()
             rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n * T.int64(3)], dtype="float32")
-            T_split_sections = T.match_buffer(var_T_split_sections, [m, n * T.int64(3) // T.int64(3)], dtype="float32")
-            T_split_sections_1 = T.match_buffer(var_T_split_sections_1, [m, n * T.int64(3) // T.int64(3) * T.int64(2) - n * T.int64(3) // T.int64(3)], dtype="float32")
-            T_split_sections_2 = T.match_buffer(var_T_split_sections_2, [m, n * T.int64(3) - n * T.int64(3) // T.int64(3) * T.int64(2)], dtype="float32")
+            T_split_sections = T.match_buffer(var_T_split_sections, [m, (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3)], dtype="float32")
+            T_split_sections_1 = T.match_buffer(var_T_split_sections_1, [m, (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3) * T.int64(2) - (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3)], dtype="float32")
+            T_split_sections_2 = T.match_buffer(var_T_split_sections_2, [m, n * T.int64(3) - (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3) * T.int64(2)], dtype="float32")
             for i0, i1 in T.grid(m, n):
                 with T.block("T_split_sections"):
                     ax0, ax1 = T.axis.remap("SS", [i0, i1])
@@ -870,9 +900,9 @@ def split(var_rxplaceholder: T.handle, var_T_split_sections: T.handle, var_T_spl
             for i0, i1 in T.grid(m, n):
                 with T.block("T_split_sections_1"):
                     ax0, ax1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[ax0, n + ax1])
+                    T.reads(rxplaceholder[ax0, ax1 + n])
                     T.writes(T_split_sections_1[ax0, ax1])
-                    T_split_sections_1[ax0, ax1] = rxplaceholder[ax0, n + ax1]
+                    T_split_sections_1[ax0, ax1] = rxplaceholder[ax0, ax1 + n]
             for i0, i1 in T.grid(m, n):
                 with T.block("T_split_sections_2"):
                     ax0, ax1 = T.axis.remap("SS", [i0, i1])
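
Note: the symbolic widths above reduce to equal thirds, since for L = 3n the
new segment size is (3n + 3 - 1) // 3 = n, making the three section widths
seg = n, 2 * seg - seg = n, and L - 2 * seg = n. A quick numeric check of that
arithmetic (illustrative only):

for n in range(1, 8):
    L = 3 * n
    seg = (L + 3 - 1) // 3  # indexdiv(src_axis_size + num_sections - 1, num_sections)
    assert (seg, 2 * seg - seg, L - 2 * seg) == (n, n, n)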