diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 26121ecdea10..cc9217c9f5f8 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -368,6 +368,7 @@ def create_convert_map( "argmin.default": self._argmax_argmin(relax.op.argmin), "where.self": self._where, # tensor manipulation + "argsort.default": self._argsort, "cat.default": self._cat, "chunk.default": self._chunk, "clamp.Tensor": self._clamp, @@ -390,6 +391,7 @@ def create_convert_map( "squeeze.dim": self._squeeze, "take.default": self._take, "tile.default": self._tile, + "topk.default": self._topk, "transpose.int": self._transpose, "unsqueeze.default": lambda node: self.block_builder.emit( relax.op.expand_dims(self.env[node.args[0]], node.args[1]) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index cc2f669d32e0..081b82b3c563 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3859,5 +3859,54 @@ def main( verify_model(Where(), (condition, x, y), {}, Expected) +def test_argsort(): + class Argsort(Module): + def forward(self, x): + return torch.argsort(x, dim=1, descending=True) + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((5, 3), dtype="float32")) -> R.Tuple(R.Tensor((5, 3), dtype="int32")): + with R.dataflow(): + lv: R.Tensor((5, 3), dtype="int32") = R.argsort( + x, axis=1, descending=True, dtype="int32" + ) + gv: R.Tuple(R.Tensor((5, 3), dtype="int32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(5, 3, dtype=torch.float32),) + verify_model(Argsort(), example_args, {}, Expected) + + +def test_topk(): + class Topk(Module): + def forward(self, x): + return torch.topk(x, k=2, dim=1, largest=True, sorted=True) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((5, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")): + with R.dataflow(): + lv: R.Tuple( + R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64") + ) = R.topk(x, k=2, axis=1, ret_type="both", largest=True, dtype="int64") + lv1: R.Tensor((5, 2), dtype="float32") = lv[0] + lv2: R.Tensor((5, 2), dtype="int64") = lv[1] + gv: R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")) = ( + lv1, + lv2, + ) + R.output(gv) + return gv + + example_args = (torch.randn(5, 3, dtype=torch.float32),) + verify_model(Topk(), example_args, {}, Expected) + + if __name__ == "__main__": tvm.testing.main()