diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 13d13ff24c28..20556167c1ce 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1278,10 +1278,15 @@ def _scatter(self, node: fx.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.scatter_elements(x, index, src, axis=dim))
 
     def _sort(self, node: fx.Node) -> relax.Var:
+        # torch.sort() returns a tuple of values and indices
+        # we use argsort to get indices and gather_elements to get values
         x = self.env[node.args[0]]
         dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
         descending = node.args[2] if len(node.args) > 2 else node.kwargs.get("descending", False)
-        return self.block_builder.emit(relax.op.sort(x, dim, descending))
+
+        indices = self.block_builder.emit(relax.op.argsort(x, dim, descending))
+        values = self.block_builder.emit(relax.op.gather_elements(x, indices, axis=dim))
+        return self.block_builder.emit(relax.Tuple([values, indices]))
 
     def _split(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index ed6740a25ef2..f38f353a9eb1 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -431,6 +431,7 @@ def create_convert_map(
             "roll.default": self._roll,
             "select.int": self._select,
             "slice.Tensor": self._slice,
+            "sort.default": self._sort,
             "split.Tensor": self._split,
             "split_with_sizes.default": self._split,
             "squeeze.default": self._squeeze,
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index 76a4bb203925..6bb35b50b1df 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -208,6 +208,31 @@ def forward(self, x):
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_sort(target, dev):
+    raw_data = np.array([[4, 1, 13], [-30, 1, 3], [4, 0, 10]]).astype("float32")
+
+    # Test values
+    class SortModelValues(nn.Module):
+        def forward(self, x):
+            A, _ = torch.sort(x, dim=0, descending=True)
+            B, _ = torch.sort(x, dim=1, descending=False)
+            return A + B
+
+    torch_module = SortModelValues().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+    # Test indices
+    class SortModelIndices(nn.Module):
+        def forward(self, x):
+            _, A = torch.sort(x, dim=0, descending=True)
+            _, B = torch.sort(x, dim=1, descending=False)
+            return A + B
+
+    torch_module = SortModelIndices().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
 @tvm.testing.parametrize_targets("cuda")
 def test_tensor_clamp(target, dev):
     class ClampBothTensor(torch.nn.Module):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index e8db6af34709..2d27fa1f5921 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4749,11 +4749,20 @@ def forward(self, x):
     class Expected:
         @R.function
         def main(
-            inp_0: R.Tensor((5, 3), dtype="float32"),
-        ) -> R.Tensor((5, 3), dtype="float32"):
+            inp_0: R.Tensor((5, 3), dtype="float32")
+        ) -> R.Tuple(R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")):
             with R.dataflow():
-                lv: R.Tensor((5, 3), dtype="float32") = R.sort(inp_0, axis=1, descending=True)
-                gv: R.Tensor((5, 3), dtype="float32") = lv
+                lv: R.Tensor((5, 3), dtype="int32") = R.argsort(
+                    inp_0, axis=1, descending=True, dtype="int32"
+                )
+                lv1: R.Tensor((5, 3), dtype="float32") = R.gather_elements(inp_0, lv, axis=1)
+                lv2: R.Tuple(R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")) = (
+                    lv1,
+                    lv,
+                )
+                gv: R.Tuple(
+                    R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")
+                ) = lv2
                 R.output(gv)
             return gv
 
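
A minimal standalone sketch (not part of the patch) of the decomposition the new _sort
lowering relies on, namely that gathering a tensor with its argsort indices reproduces the
values returned by torch.sort(); the dim/descending pairs below are illustrative only.

    import torch

    x = torch.tensor([[4.0, 1.0, 13.0], [-30.0, 1.0, 3.0], [4.0, 0.0, 10.0]])

    for dim, descending in [(0, True), (1, False)]:
        ref_values, _ = torch.sort(x, dim=dim, descending=descending)
        # argsort provides the sorting indices; torch.gather plays the role of
        # relax.op.gather_elements and recovers the sorted values from them.
        indices = torch.argsort(x, dim=dim, descending=descending)
        values = torch.gather(x, dim, indices)
        assert torch.equal(values, ref_values)
        # With duplicate entries the tie order is implementation-dependent, so index
        # tensors from different backends are only guaranteed to agree up to ties.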