diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index ecd8665b4353..71554a8a5bab 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -956,6 +956,17 @@ def _cat(self, node: fx.Node) -> relax.Var:
         axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
         return self.block_builder.emit(relax.op.concat(args[0], axis=axis))
 
+    def _chunk(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        chunks = node.args[1]
+        dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
+        x_shape = self.shape_of(x)
+        max_chunks = x_shape[dim].value
+        n_sections = min(chunks, max_chunks)
+        return self.block_builder.emit(
+            relax.op.split(x=x, indices_or_sections=n_sections, axis=dim)
+        )
+
     def _cumsum(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 37caae6c9854..4319fbebe74a 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -359,6 +359,7 @@ def create_convert_map(
             "argmin.default": self._argmax_argmin(relax.op.argmin),
             # tensor manipulation
             "cat.default": self._cat,
+            "chunk.default": self._chunk,
             "clamp.Tensor": self._clamp,
             "concat.default": self._cat,
             "copy_.default": self._copy_,
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index f7501dd3b5b3..64babdc43a5c 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
-
 
 import tvm
 from tvm import relax
 import tvm.testing
@@ -332,6 +331,30 @@ def forward(self, x):
         assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_split_sections_list(target, dev):
+    # Test split using a list of section sizes
+    batch = 3
+    channels = 2
+    height = 10
+    width = 5
+    sections = [3, 2, 5]
+    dim = 2  # split across height
+    raw_data = np.random.rand(batch, channels, height, width).astype("float32")
+
+    class SplitModelSectionsList(nn.Module):
+        def __init__(self, split_size, dim):
+            super().__init__()
+            self.split_size = split_size
+            self.dim = dim
+
+        def forward(self, x):
+            return torch.split(x, split_size_or_sections=self.split_size, dim=self.dim)
+
+    torch_module = SplitModelSectionsList(split_size=sections, dim=dim).eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
 @tvm.testing.parametrize_targets("cuda")
 def test_batch_norm0(target, dev):
     # Eval, no momentum, no affine, no running stats
@@ -373,26 +396,74 @@ def test_batch_norm3(target, dev):
 
 
 @tvm.testing.parametrize_targets("cuda")
-def test_split_sections_list(target, dev):
-    # Test split using a list of section sizes
-    batch = 3
+def test_chunk_even(target, dev):
+    # Chunks is a divisor of the dimension size
+    batch = 6
     channels = 2
-    height = 10
+    height = 3
+    width = 4
+    chunks = 3
+    dim = 0
+    raw_data = np.random.rand(batch, channels, height, width).astype("float32")
+
+    class ChunkModel(nn.Module):
+        def __init__(self, chunks, dim):
+            super().__init__()
+            self.chunks = chunks
+            self.dim = dim
+
+        def forward(self, x):
+            return x.chunk(self.chunks, dim=self.dim)
+
+    torch_module = ChunkModel(chunks=chunks, dim=dim).eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_chunk_uneven(target, dev):
+    # Chunks is not a divisor of the dimension size
+    batch = 2
+    channels = 5
+    height = 4
     width = 5
-    sections = [3, 2, 5]
-    dim = 2  # split across height
+    chunks = 2
+    dim = 1
     raw_data = np.random.rand(batch, channels, height, width).astype("float32")
 
-    class SplitModelSectionsList(nn.Module):
-        def __init__(self, split_size, dim):
+    class ChunkModel(nn.Module):
+        def __init__(self, chunks, dim):
             super().__init__()
-            self.split_size = split_size
+            self.chunks = chunks
             self.dim = dim
 
         def forward(self, x):
-            return torch.split(x, split_size_or_sections=self.split_size, dim=self.dim)
+            return x.chunk(self.chunks, dim=self.dim)
 
-    torch_module = SplitModelSectionsList(split_size=sections, dim=dim).eval()
+    torch_module = ChunkModel(chunks=chunks, dim=dim).eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_chunk_too_many(target, dev):
+    # If the user asks for more chunks than the size of the dim, PyTorch simply splits into sections of size 1
+    batch = 1
+    channels = 3
+    height = 2
+    width = 2
+    chunks = 99
+    dim = 1
+    raw_data = np.random.rand(batch, channels, height, width).astype("float32")
+
+    class ChunkModel(nn.Module):
+        def __init__(self, chunks, dim):
+            super().__init__()
+            self.chunks = chunks
+            self.dim = dim
+
+        def forward(self, x):
+            return x.chunk(self.chunks, dim=self.dim)
+
+    torch_module = ChunkModel(chunks=chunks, dim=dim).eval()
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
 
 
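Note on the semantics the `_chunk` lowering relies on: `torch.chunk` sizes sections by ceil-division and clamps to size-1 sections when the requested count exceeds the dimension, which is what the `min(chunks, max_chunks)` clamp reproduces before calling `relax.op.split`. A minimal standalone sketch (not part of the patch) of the two edge cases the new tests cover, assuming `relax.op.split` with an integer section count uses the same ceil-division sizing (the uneven test above exercises exactly this):

    import torch

    x = torch.arange(10).reshape(2, 5)

    # Uneven case: 5 columns split into 2 chunks gives sizes ceil(5/2)=3 and 2.
    print([t.shape[1] for t in x.chunk(2, dim=1)])   # [3, 2]

    # Too-many case: 99 chunks on a size-5 dim yields 5 sections of size 1,
    # which is why _chunk passes min(chunks, dim_size) to relax.op.split.
    print([t.shape[1] for t in x.chunk(99, dim=1)])  # [1, 1, 1, 1, 1]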
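For completeness, a hedged end-to-end sketch of driving the new `chunk.default` entry through the exported-program frontend; the module name, shapes, and chunk count here are arbitrary illustration, not taken from the patch:

    import torch
    from torch.export import export
    from tvm.relax.frontend.torch import from_exported_program

    class M(torch.nn.Module):
        def forward(self, x):
            return x.chunk(3, dim=0)

    # Export the module and translate it; the chunk call should lower to a
    # single relax.op.split via the _chunk handler added in this patch.
    exported = export(M().eval(), (torch.randn(6, 4),))
    mod = from_exported_program(exported)
    mod.show()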