diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 83ee1d3377f4..9406c3b2ea9b 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1594,28 +1594,11 @@ def chunk(self, inputs, input_types): else: unif_size = int(dim / num_chunks) - chunks = [] - for i in range(0, dim, unif_size): - begin = [0] * len(shape) - end = shape[:] - begin[axis] = i - end[axis] = i + unif_size - stride = [1] * len(shape) + indeces = [] + for i in range(unif_size, dim, unif_size): + indeces.append(i) - chunk_out = _op.transform.strided_slice(data, begin=begin, end=end, strides=stride) - chunks.append(chunk_out) - - if dim % num_chunks: - begin = [0] * len(shape) - end = shape[:] - begin[axis] = unif_size * (num_chunks - 1) - end[axis] = dim - stride = [1] * len(shape) - - chunk_out = _op.transform.strided_slice(data, begin=begin, end=end, strides=stride) - chunks.append(chunk_out) - - return chunks + return _op.split(data, indeces, axis) def matmul(self, inputs, input_types): @@ -2681,6 +2664,7 @@ def create_convert_map(self): "aten::alpha_dropout": self.dropout, "aten::mean": self.mean, "aten::chunk": self.chunk, + "aten::unsafe_chunk": self.chunk, "aten::matmul": self.matmul, "aten::bmm": self.matmul, "aten::expand": self.expand,