21 changes: 8 additions & 13 deletions include/tvm/topi/transform.h
@@ -575,8 +575,9 @@ inline Tensor stack(const Array<Tensor>& inputs, int axis = 0, std::string name
*
* \return A Tensor whose op member is the split operation
*/
-inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int axis,
-                           std::string name = "T_split", std::string tag = kInjective) {
+inline Array<Tensor> split_indices_array(const Tensor& x, Array<PrimExpr> split_indices, int axis,
+                                         std::string name = "T_split",
+                                         std::string tag = kInjective) {
  if (axis < 0) {
    axis += static_cast<int>(x->shape.size());
  }
@@ -968,9 +969,9 @@ inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const
*
* \return A Tensor whose op member is the split operation
*/
-inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
-                                    std::string name = "T_split_sections",
-                                    std::string tag = kInjective) {
+inline Array<Tensor> split_n_sections(const Tensor& x, int num_sections, int axis,
+                                      std::string name = "T_split_sections",
+                                      std::string tag = kInjective) {
  if (axis < 0) {
    axis += static_cast<int>(x->shape.size());
  }
@@ -980,22 +981,16 @@ inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,

  ICHECK_GT(num_sections, 0) << "Slice count must be > 0";

-  if (auto node = src_axis_size.as<IntImmNode>()) {
-    ICHECK_EQ(node->value % num_sections, 0)
-        << "num_sections must be an integer factor of the size of axis " << axis << " ("
-        << node->value << ")";
-  }

  Array<PrimExpr> split_indices;
-  auto seg_size = indexdiv(src_axis_size, num_sections);
+  auto seg_size = indexdiv(src_axis_size + num_sections - 1, num_sections);
  for (int i = 0; i < num_sections; ++i) {
    // region at index 0 is added by split()
    if (i != 0) {
      split_indices.push_back(seg_size * i);
    }
  }

-  return split(x, split_indices, axis, name, tag);
+  return split_indices_array(x, split_indices, axis, name, tag);
}
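
The ceiling division above is the behavioral change: when the axis length is not a multiple of num_sections, every section gets ceil(size / num_sections) elements and the last section simply takes the remainder, instead of tripping the old ICHECK. A minimal Python sketch of the index computation (illustration only, not part of the patch; split_indices_ceil is a made-up name):

import numpy as np

def split_indices_ceil(axis_size, num_sections):
    # Mirrors indexdiv(src_axis_size + num_sections - 1, num_sections) above.
    seg_size = (axis_size + num_sections - 1) // num_sections
    # The region starting at index 0 is implicit; only interior cut points are emitted.
    return [seg_size * i for i in range(1, num_sections)]

x = np.arange(10)
parts = np.split(x, split_indices_ceil(10, 3))  # cut points [4, 8]
print([p.size for p in parts])  # [4, 4, 2]: the last section absorbs the shortfall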

/*!
@@ -329,6 +329,7 @@ def create_convert_map(
"select.int": self._select,
"slice.Tensor": self._slice,
"split.Tensor": self._split,
"split_with_sizes.default": self._split,
"squeeze.default": self._squeeze,
"squeeze.dim": self._squeeze,
"take.default": self._take,
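
torch.export traces torch.split with an int split size to aten::split.Tensor but a list of section sizes to aten::split_with_sizes, so the new entry routes both overloads to the same _split converter. A hedged PyTorch illustration (shapes arbitrary; assumes only a stock PyTorch install):

import torch

x = torch.randn(3, 2, 10, 5)
chunks = torch.split(x, [3, 2, 5], dim=2)  # exported as split_with_sizes.default
print([tuple(c.shape) for c in chunks])
# [(3, 2, 3, 5), (3, 2, 2, 5), (3, 2, 5, 5)]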
11 changes: 0 additions & 11 deletions python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -16,7 +16,6 @@
# under the License.
# pylint: disable=invalid-name
"""Default legalization function for manipulate operators."""
-import logging
from typing import Optional

import tvm
@@ -109,16 +108,6 @@ def _permute_dims(bb: BlockBuilder, call: Call) -> Expr:
def _split(bb: BlockBuilder, call: Call) -> Expr:
    if isinstance(call.attrs.indices_or_sections, tir.IntImm):
        indices_or_sections = call.attrs.indices_or_sections.value
-        modulo = tvm.arith.Analyzer().simplify(
-            call.args[0].struct_info.shape.values[call.attrs.axis] % indices_or_sections
-        )
-        if isinstance(modulo, tir.IntImm):
-            if modulo != 0:
-                logging.info(
-                    "Split cannot be legalized by TOPI when the axis being split has "
-                    "length that not divisible by the input number of section."
-                )
-                return call
    else:
        indices_or_sections = call.attrs.indices_or_sections
    return bb.call_te(topi.split, call.args[0], indices_or_sections, call.attrs.axis)
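
With the divisibility guard gone, an integer section count always legalizes through topi.split, and the chunk sizes follow the ceiling-division rule from transform.h. Note that this matches torch.chunk (equal chunks, with a smaller last one) rather than numpy.array_split, which balances the remainder; a hedged comparison of the two conventions:

import numpy as np
import torch

# Splitting a length-10 axis into 3 sections:
print([t.shape[0] for t in torch.chunk(torch.arange(10), 3)])  # [4, 4, 2], what TOPI now produces
print([a.size for a in np.array_split(np.arange(10), 3)])      # [4, 3, 3], NumPy balances instead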
4 changes: 2 additions & 2 deletions src/topi/transform.cc
@@ -84,9 +84,9 @@ TVM_REGISTER_GLOBAL("topi.ndarray_size").set_body([](TVMArgs args, TVMRetValue*

TVM_REGISTER_GLOBAL("topi.split").set_body([](TVMArgs args, TVMRetValue* rv) {
if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) {
*rv = split_sections(args[0], args[1], args[2]);
*rv = split_n_sections(args[0], args[1], args[2]);
} else {
*rv = split(args[0], args[1], args[2]);
*rv = split_indices_array(args[0], args[1], args[2]);
}
});
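
The single "topi.split" FFI entry still dispatches on the second argument's type: an integer selects split_n_sections, an array of cut indices selects split_indices_array. A hedged usage sketch (assumes a build that includes this patch; topi.split here is the Python wrapper over the registered global, and the placeholder shape is chosen arbitrarily):

import tvm
from tvm import te, topi

x = te.placeholder((2, 10, 4), "float32", name="x")
by_sections = topi.split(x, 3, axis=1)      # int -> split_n_sections: sizes 4, 4, 2 on axis 1
by_indices = topi.split(x, [4, 8], axis=1)  # indices -> split_indices_array: same sections
print([[int(d) for d in t.shape] for t in by_sections])  # [[2, 4, 4], [2, 4, 4], [2, 2, 4]]
print([[int(d) for d in t.shape] for t in by_indices])   # identical output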

66 changes: 62 additions & 4 deletions tests/python/relax/test_from_exported_to_cuda.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.


import tvm
from tvm import relax
import tvm.testing
@@ -50,10 +51,17 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
    gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]]
    gpu_out = vm["main"](gpu_data, *gpu_params)

-    pytorch_out = torch_module(torch_data).detach().numpy()
-    actual = gpu_out[0].numpy()
-    desired = pytorch_out
-    np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
+    pytorch_out = torch_module(torch_data)
+
+    if isinstance(pytorch_out, tuple):
+        for i in range(len(pytorch_out)):
+            actual = gpu_out[i].numpy()
+            desired = pytorch_out[i].detach().numpy()
+            np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
+    else:
+        actual = gpu_out[0].numpy()
+        desired = pytorch_out.detach().numpy()
+        np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)


@tvm.testing.parametrize_targets("cuda")
@@ -281,5 +289,55 @@ def forward(self, x):
    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev)


@tvm.testing.parametrize_targets("cuda")
def test_split_size(target, dev):
# Test split using the split_size argument such that it is not a divisor
# of the dimension to split (the last tensor will be smaller)
batch = 2
channels = 7
height, width = 2, 2
split_size = 3 # last tensor will have just 1 element
dim = 1 # split across channels
raw_data = np.random.rand(batch, channels, height, width).astype("float32")

class SplitModelSplitSize(nn.Module):
def __init__(self, split_size, dim):
super().__init__()
self.split_size = split_size
self.dim = dim

def forward(self, x):
return torch.split(x, split_size_or_sections=self.split_size, dim=self.dim)

torch_module = SplitModelSplitSize(split_size=split_size, dim=dim).eval()

assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


@tvm.testing.parametrize_targets("cuda")
def test_split_sections_list(target, dev):
# Test split using a list of section sizes
batch = 3
channels = 2
height = 10
width = 5
sections = [3, 2, 5]
dim = 2 # split across height
raw_data = np.random.rand(batch, channels, height, width).astype("float32")

class SplitModelSectionsList(nn.Module):
def __init__(self, split_size, dim):
super().__init__()
self.split_size = split_size
self.dim = dim

def forward(self, x):
return torch.split(x, split_size_or_sections=self.split_size, dim=self.dim)

torch_module = SplitModelSectionsList(split_size=sections, dim=dim).eval()

assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


if __name__ == "__main__":
    tvm.testing.main()
50 changes: 42 additions & 8 deletions tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -15,6 +15,10 @@
# specific language governing permissions and limitations
# under the License.

import tvm
from tvm import relax
from tvm.relax.transform import LegalizeOps
@@ -788,12 +792,42 @@ def test_split_by_indices_n_section_indivisible():
    class Split:
        @R.function
        def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]):
-            gv: R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]) = R.split(x, 3, axis=1)
+            gv: R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]) = R.split(x, indices_or_sections=3, axis=1)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")]):
            gv = R.call_tir(Expected.split, (x,), [R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 4, 4), "float32"), R.Tensor((2, 2, 4), "float32")])
            return gv

        @T.prim_func(private=True)
        def split(rxplaceholder: T.Buffer((T.int64(2), T.int64(10), T.int64(4)), "float32"), T_split_sections: T.Buffer((T.int64(2), T.int64(4), T.int64(4)), "float32"), T_split_sections_1: T.Buffer((T.int64(2), T.int64(4), T.int64(4)), "float32"), T_split_sections_2: T.Buffer((T.int64(2), T.int64(2), T.int64(4)), "float32")):
            T.func_attr({"tir.noalias": True})
            for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(4)):
                with T.block("T_split_sections"):
                    ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
                    T.reads(rxplaceholder[ax0, ax1, ax2])
                    T.writes(T_split_sections[ax0, ax1, ax2])
                    T_split_sections[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2]
            for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(4)):
                with T.block("T_split_sections_1"):
                    ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
                    T.reads(rxplaceholder[ax0, ax1 + T.int64(4), ax2])
                    T.writes(T_split_sections_1[ax0, ax1, ax2])
                    T_split_sections_1[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 + T.int64(4), ax2]
            for i0, i1, i2 in T.grid(T.int64(2), T.int64(2), T.int64(4)):
                with T.block("T_split_sections_2"):
                    ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
                    T.reads(rxplaceholder[ax0, ax1 + T.int64(8), ax2])
                    T.writes(T_split_sections_2[ax0, ax1, ax2])
                    T_split_sections_2[ax0, ax1, ax2] = rxplaceholder[ax0, ax1 + T.int64(8), ax2]

    # fmt: on

    mod = LegalizeOps()(Split)
-    tvm.ir.assert_structural_equal(mod, Split)
+    tvm.ir.assert_structural_equal(mod, Expected)


def test_split_by_indices_n_section_divisible():
@@ -850,17 +884,17 @@ class Expected:
        def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "(n * 3)"), "float32")) -> R.Tuple(R.Tensor(("m", "((n * 3) // 3)"), "float32"), R.Tensor(("m", "((((n * 3) // 3) * 2) - ((n * 3) // 3))"), "float32"), R.Tensor(("m", "((n * 3) - (((n * 3) // 3) * 2))"), "float32")):
            m = T.int64()
            n = T.int64()
-            gv = R.call_tir(Expected.split, (x,), [R.Tensor((m, ((n * 3) // 3)), "float32"), R.Tensor((m, ((((n * 3) // 3) * 2) - ((n * 3) // 3))), "float32"), R.Tensor((m, ((n * 3) - (((n * 3) // 3) * 2))), "float32")], tir_vars=(n,))
+            gv = R.call_tir(Expected.split, (x,), [R.Tensor((m, ((n * 3 + 3 - 1) // 3)), "float32"), R.Tensor((m, ((((n * 3 + 3 - 1) // 3) * 2) - ((n * 3 + 3 - 1) // 3))), "float32"), R.Tensor((m, ((n * 3) - (((n * 3 + 3 - 1) // 3) * 2))), "float32")], tir_vars=R.shape([n]))
            return gv

        @T.prim_func(private=True)
        def split(var_rxplaceholder: T.handle, var_T_split_sections: T.handle, var_T_split_sections_1: T.handle, var_T_split_sections_2: T.handle, n: T.int64):
            T.func_attr({"tir.noalias": True})
            m = T.int64()
            rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n * T.int64(3)], dtype="float32")
-            T_split_sections = T.match_buffer(var_T_split_sections, [m, n * T.int64(3) // T.int64(3)], dtype="float32")
-            T_split_sections_1 = T.match_buffer(var_T_split_sections_1, [m, n * T.int64(3) // T.int64(3) * T.int64(2) - n * T.int64(3) // T.int64(3)], dtype="float32")
-            T_split_sections_2 = T.match_buffer(var_T_split_sections_2, [m, n * T.int64(3) - n * T.int64(3) // T.int64(3) * T.int64(2)], dtype="float32")
+            T_split_sections = T.match_buffer(var_T_split_sections, [m, (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3)], dtype="float32")
+            T_split_sections_1 = T.match_buffer(var_T_split_sections_1, [m, (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3) * T.int64(2) - (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3)], dtype="float32")
+            T_split_sections_2 = T.match_buffer(var_T_split_sections_2, [m, n * T.int64(3) - (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3) * T.int64(2)], dtype="float32")
            for i0, i1 in T.grid(m, n):
                with T.block("T_split_sections"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
@@ -870,9 +904,9 @@ def split(var_rxplaceholder: T.handle, var_T_split_sections: T.handle, var_T_spl
            for i0, i1 in T.grid(m, n):
                with T.block("T_split_sections_1"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(rxplaceholder[ax0, n + ax1])
+                    T.reads(rxplaceholder[ax0, ax1 + n])
                    T.writes(T_split_sections_1[ax0, ax1])
-                    T_split_sections_1[ax0, ax1] = rxplaceholder[ax0, n + ax1]
+                    T_split_sections_1[ax0, ax1] = rxplaceholder[ax0, ax1 + n]
            for i0, i1 in T.grid(m, n):
                with T.block("T_split_sections_2"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])