From 652dc9e87ebece0bd3f8187b6ec59bbf93dfc116 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sun, 6 Apr 2025 16:30:01 +0800 Subject: [PATCH 1/4] Update exported_program_translator.py --- python/tvm/relax/frontend/torch/exported_program_translator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 26121ecdea10..cc9217c9f5f8 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -368,6 +368,7 @@ def create_convert_map( "argmin.default": self._argmax_argmin(relax.op.argmin), "where.self": self._where, # tensor manipulation + "argsort.default": self._argsort, "cat.default": self._cat, "chunk.default": self._chunk, "clamp.Tensor": self._clamp, @@ -390,6 +391,7 @@ def create_convert_map( "squeeze.dim": self._squeeze, "take.default": self._take, "tile.default": self._tile, + "topk.default": self._topk, "transpose.int": self._transpose, "unsqueeze.default": lambda node: self.block_builder.emit( relax.op.expand_dims(self.env[node.args[0]], node.args[1]) From a0f530e0fdd268629ea995ba39641b10d60c676e Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sun, 6 Apr 2025 16:31:15 +0800 Subject: [PATCH 2/4] Update test_frontend_from_exported_program.py --- .../test_frontend_from_exported_program.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index cc2f669d32e0..47ef79bb13e2 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3858,6 +3858,49 @@ def main( verify_model(Where(), (condition, x, y), {}, Expected) +def test_argsort(): + class Argsort(Module): + def forward(self, x): + return torch.argsort(x, dim=1, descending=True) + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((5, 3), dtype="float32")) -> R.Tuple(R.Tensor((5, 3), dtype="int32")): + with R.dataflow(): + lv: R.Tensor((5, 3), dtype="int32") = R.argsort(x, axis=1, descending=True, dtype="int32") + gv: R.Tuple(R.Tensor((5, 3), dtype="int32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(5, 3, dtype=torch.float32),) + verify_model(Argsort(), example_args, {}, Expected) + + +def test_topk(): + class Topk(Module): + def forward(self, x): + return torch.topk(x, k=2, dim=1, largest=True, sorted=True) + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((5, 3), dtype="float32")) -> R.Tuple(R.Tensor((5, 2), dtype="float32"), + R.Tensor((5, 2), dtype="int64")): + with R.dataflow(): + lv: R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")) = R.topk(x, k=2, axis=1, + ret_type="both", + largest=True, + dtype="int64") + lv1: R.Tensor((5, 2), dtype="float32") = lv[0] + lv2: R.Tensor((5, 2), dtype="int64") = lv[1] + gv: R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")) = (lv1, lv2) + R.output(gv) + return gv + + example_args = (torch.randn(5, 3, dtype=torch.float32),) + verify_model(Topk(), example_args, {}, Expected) + if __name__ == "__main__": tvm.testing.main() From 8dd093fc5e474a39d6832743b1f17a897c71083f Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sun, 6 Apr 2025 23:40:56 +0800 Subject: [PATCH 3/4] Update test_frontend_from_exported_program.py --- .../test_frontend_from_exported_program.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 47ef79bb13e2..5430b3d7ee12 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3858,6 +3858,7 @@ def main( verify_model(Where(), (condition, x, y), {}, Expected) + def test_argsort(): class Argsort(Module): def forward(self, x): @@ -3868,7 +3869,9 @@ class Expected: @R.function def main(x: R.Tensor((5, 3), dtype="float32")) -> R.Tuple(R.Tensor((5, 3), dtype="int32")): with R.dataflow(): - lv: R.Tensor((5, 3), dtype="int32") = R.argsort(x, axis=1, descending=True, dtype="int32") + lv: R.Tensor((5, 3), dtype="int32") = R.argsort( + x, axis=1, descending=True, dtype="int32" + ) gv: R.Tuple(R.Tensor((5, 3), dtype="int32")) = (lv,) R.output(gv) return gv @@ -3885,16 +3888,19 @@ def forward(self, x): @tvm.script.ir_module class Expected: @R.function - def main(x: R.Tensor((5, 3), dtype="float32")) -> R.Tuple(R.Tensor((5, 2), dtype="float32"), - R.Tensor((5, 2), dtype="int64")): + def main( + x: R.Tensor((5, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")): with R.dataflow(): - lv: R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")) = R.topk(x, k=2, axis=1, - ret_type="both", - largest=True, - dtype="int64") + lv: R.Tuple( + R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64") + ) = R.topk(x, k=2, axis=1, ret_type="both", largest=True, dtype="int64") lv1: R.Tensor((5, 2), dtype="float32") = lv[0] lv2: R.Tensor((5, 2), dtype="int64") = lv[1] - gv: R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")) = (lv1, lv2) + gv: R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")) = ( + lv1, + lv2 + ) R.output(gv) return gv From 6966e2ff9b3e6df164e8ada974954759db7ad616 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Mon, 7 Apr 2025 00:28:37 +0800 Subject: [PATCH 4/4] Update test_frontend_from_exported_program.py --- tests/python/relax/test_frontend_from_exported_program.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 5430b3d7ee12..081b82b3c563 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3899,7 +3899,7 @@ def main( lv2: R.Tensor((5, 2), dtype="int64") = lv[1] gv: R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")) = ( lv1, - lv2 + lv2, ) R.output(gv) return gv