From a10e0d935631d607f12b5b4fe25e8905c91d3059 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Thu, 17 Apr 2025 23:07:29 +0800 Subject: [PATCH 01/20] Update base_fx_graph_translator.py --- .../torch/base_fx_graph_translator.py | 34 ------------------- 1 file changed, 34 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 4c9480b58748..6d880ab90dc2 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -927,40 +927,6 @@ def _mean(self, node: fx.Node) -> relax.Var: keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim)) - def _norm(self, node: fx.Node) -> relax.Var: - data = self.env[node.args[0]] - dtype = data.struct_info.dtype - order = node.args[1] if len(node.args) > 1 else node.kwargs.get("p", 2) - axis = node.args[2] if len(node.args) > 2 else None - keepdims = node.args[3] if len(node.args) > 3 else False - - if order == float("inf"): - return self.block_builder.emit( - relax.op.max(relax.op.abs(data), axis=axis, keepdims=keepdims) - ) - elif order == float("-inf"): - return self.block_builder.emit( - relax.op.min(relax.op.abs(data), axis=axis, keepdims=keepdims) - ) - # frobenius_norm - elif order == "fro": - return self.block_builder.emit( - relax.op.sqrt( - relax.op.sum(relax.op.multiply(data, data), axis=axis, keepdims=keepdims), - ) - ) - else: - reci_order = relax.const(1 / order, dtype=dtype) - order = relax.const(order, dtype=dtype) - return self.block_builder.emit( - relax.op.power( - relax.op.sum( - relax.op.power(relax.op.abs(data), order), axis=axis, keepdims=keepdims - ), - reci_order, - ) - ) - def _prod(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] From 941aa4dcfc98083c44061407f7fbd9214fe2ebf7 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Thu, 17 Apr 2025 23:10:06 +0800 Subject: [PATCH 02/20] Update fx_translator.py --- .../tvm/relax/frontend/torch/fx_translator.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 297529e8bf29..fc283e42834d 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -414,6 +414,29 @@ def _lerp(self, node: fx.Node) -> relax.Var: relax.op.add(start, relax.op.multiply(weight, relax.op.subtract(end, start))) ) + ########## Statistical ########## + + def _norm(self, node: fx.Node) -> relax.Var: + data = self.env[node.args[0]] + dtype = data.struct_info.dtype + order = node.args[1] if len(node.args) > 1 else node.kwargs.get("p", 2) + axis = node.args[2] if len(node.args) > 2 else None + keepdims = node.args[3] if len(node.args) > 3 else False + + if order == float("inf"): + return self.block_builder.emit(relax.op.max(relax.op.abs(data), axis=axis, keepdims=keepdims)) + elif order == float("-inf"): + return self.block_builder.emit(relax.op.min(relax.op.abs(data), axis=axis, keepdims=keepdims)) + else: + reci_order = relax.const(1 / order, dtype=dtype) + order = relax.const(order, dtype=dtype) + return self.block_builder.emit( + relax.op.power( + relax.op.sum(relax.op.power(relax.op.abs(data), order), axis=axis, keepdims=keepdims), + reci_order, + ) + ) + ########## Manipulation ########## def _chunk(self, 
node: fx.Node) -> relax.Var: From b31f4b415aee47ff00881892a9bb7108bd3ae7ce Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Thu, 17 Apr 2025 23:10:49 +0800 Subject: [PATCH 03/20] Update exported_program_translator.py --- .../torch/exported_program_translator.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index c82a5e2b1100..eddae0870ec0 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -200,6 +200,29 @@ def _upsample_nearest2d(self, node: fx.node) -> relax.Var: align_corners=align_corners, ) + ########## Statistical ########## + + def _norm(self, node: fx.Node) -> relax.Var: + data = self.env[node.args[0]] + dtype = data.struct_info.dtype + order = node.args[1] if len(node.args) > 1 else node.kwargs.get("p", 2) + axis = node.args[2] if len(node.args) > 2 else None + keepdims = node.args[3] if len(node.args) > 3 else False + + if order == float("inf"): + return self.block_builder.emit(relax.op.max(relax.op.abs(data), axis=axis, keepdims=keepdims)) + elif order == float("-inf"): + return self.block_builder.emit(relax.op.min(relax.op.abs(data), axis=axis, keepdims=keepdims)) + else: + reci_order = relax.const(1 / order, dtype=dtype) + order = relax.const(order, dtype=dtype) + return self.block_builder.emit( + relax.op.power( + relax.op.sum(relax.op.power(relax.op.abs(data), order), axis=axis, keepdims=keepdims), + reci_order, + ) + ) + ########## Manipulation ########## def _narrow(self, node: fx.Node) -> relax.Var: From cef6a602726d859ed5ac3ae002b345c434cfb5d2 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Thu, 17 Apr 2025 23:11:20 +0800 Subject: [PATCH 04/20] Update exported_program_translator.py --- python/tvm/relax/frontend/torch/exported_program_translator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index eddae0870ec0..ef88b3e601ce 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -410,6 +410,7 @@ def create_convert_map( "upsample_nearest2d.vec": self._upsample_nearest2d, # statistical "mean.dim": self._mean, + "linalg_vector_norm.default": self._norm, "prod.default": self._prod, "std.correction": self._std, "sum.default": self._sum, From e81218635604511cca8bf57ee385256d96fdce9d Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Thu, 17 Apr 2025 23:11:55 +0800 Subject: [PATCH 05/20] Update fx_translator.py --- python/tvm/relax/frontend/torch/fx_translator.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index fc283e42834d..ebabe2e44c0d 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -427,6 +427,11 @@ def _norm(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.max(relax.op.abs(data), axis=axis, keepdims=keepdims)) elif order == float("-inf"): return self.block_builder.emit(relax.op.min(relax.op.abs(data), axis=axis, keepdims=keepdims)) + # frobenius_norm + elif order == "fro": + return self.block_builder.emit( + 
relax.op.sqrt(relax.op.sum(relax.op.multiply(data, data), axis=axis, keepdims=keepdims)) + ) else: reci_order = relax.const(1 / order, dtype=dtype) order = relax.const(order, dtype=dtype) From adcd48cc1a6a8cf23725bb982e2d1572bb13dce9 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Thu, 17 Apr 2025 23:12:55 +0800 Subject: [PATCH 06/20] Update test_frontend_from_fx.py --- tests/python/relax/test_frontend_from_fx.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index a962de8a3237..41c6b92680e8 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -4626,19 +4626,16 @@ def main( return gv norms = [ - (float("inf"), None, False), - (float("-inf"), None, False), - (float(2), None, False), - (float(1.0), None, False), - (float(-4), None, True), - (float(0.5), None, True), - ("fro", None, False), + ((float('inf'), None, False), Expected1), + ((float('-inf'), None, False), Expected2), + ((float(2), None, False), Expected3), + ((float(1.0), None, False), Expected4), + ((float(-4), None, True), Expected5), + ((float(0.5), None, True), Expected6), + (("fro", None, False), Expected7) ] - for norm, expected in zip( - norms, [Expected1, Expected2, Expected3, Expected4, Expected5, Expected6, Expected7] - ): - p, dim, keepdim = norm + for (p, dim, keepdim), expected in norms: verify_model(Norm(p, dim=dim, keepdim=keepdim), input_info, {}, expected) From a1371246f17020f54d48cf6b797bdf48f972ba88 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Thu, 17 Apr 2025 23:14:46 +0800 Subject: [PATCH 07/20] Update test_frontend_from_exported_program.py --- .../test_frontend_from_exported_program.py | 112 ++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 26d3d3f7bde2..2d213582f1be 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3932,5 +3932,117 @@ def main( verify_model(Narrow(), example_args, {}, Expected) +def test_norm(): + + class Norm(Module): + def __init__(self, p, dim=None, keepdim=False): + super().__init__() + self.p = p + self.dim = dim + self.keepdim = keepdim + + def forward(self, x): + return torch.norm(x, p=self.p, dim=self.dim, keepdim=self.keepdim) + + @tvm.script.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"), + ) -> R.Tuple(R.Tensor((), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((), dtype="float32") = R.max(R.abs(inp_0), axis=None, keepdims=False) + gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class Expected2: + @R.function + def main( + inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"), + ) -> R.Tuple(R.Tensor((), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((), dtype="float32") = R.min(R.abs(inp_0), axis=None, keepdims=False) + gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class Expected3: + @R.function + def main( + inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"), + ) -> R.Tuple(R.Tensor((), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0) + lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, 
R.const(2, "float32")) + lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False) + lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(0.5, "float32")) + gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + @tvm.script.ir_module + class Expected4: + @R.function + def main( + inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"), + ) -> R.Tuple(R.Tensor((), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0) + lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(1.0, "float32")) + lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False) + lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(1.0, "float32")) + gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + @tvm.script.ir_module + class Expected5: + @R.function + def main( + inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0) + lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(-4.0, "float32")) + lv2: R.Tensor((1, 1, 1, 1), dtype="float32") = R.sum(lv1, axis=None, keepdims=True) + lv3: R.Tensor((1, 1, 1, 1), dtype="float32") = R.power(lv2, R.const(-0.25, "float32")) + gv: R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + @tvm.script.ir_module + class Expected6: + @R.function + def main( + inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0) + lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(0.5, "float32")) + lv2: R.Tensor((1, 1, 1, 1), dtype="float32") = R.sum(lv1, axis=None, keepdims=True) + lv3: R.Tensor((1, 1, 1, 1), dtype="float32") = R.power(lv2, R.const(2.0, "float32")) + gv: R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + + norms = [ + ((float('inf'), None, False), Expected1), + ((float('-inf'), None, False), Expected2), + ((float(2), None, False), Expected3), + ((float(1.0), None, False), Expected4), + ((float(-4), None, True), Expected5), + ((float(0.5), None, True), Expected6), + ] + + example_args = (torch.randn(1, 3, 5, 3, dtype=torch.float32),) + + for (p, dim, keepdim), expected in norms: + verify_model(Norm(p, dim=dim, keepdim=keepdim), example_args, {}, expected) + + if __name__ == "__main__": tvm.testing.main() From 0284a89650b1d60615be2879d32d23ea485daa8f Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Fri, 18 Apr 2025 22:47:59 +0800 Subject: [PATCH 08/20] Update fx_translator.py --- python/tvm/relax/frontend/torch/fx_translator.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index fd7a76394871..fd2bc62f5380 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -432,20 +432,28 @@ def _norm(self, node: fx.Node) -> relax.Var: keepdims = node.args[3] if len(node.args) > 3 else False if order == float("inf"): - return self.block_builder.emit(relax.op.max(relax.op.abs(data), axis=axis, keepdims=keepdims)) + return self.block_builder.emit( + relax.op.max(relax.op.abs(data), axis=axis, keepdims=keepdims) + ) elif order == float("-inf"): - return 
self.block_builder.emit(relax.op.min(relax.op.abs(data), axis=axis, keepdims=keepdims)) + return self.block_builder.emit( + relax.op.min(relax.op.abs(data), axis=axis, keepdims=keepdims) + ) # frobenius_norm elif order == "fro": return self.block_builder.emit( - relax.op.sqrt(relax.op.sum(relax.op.multiply(data, data), axis=axis, keepdims=keepdims)) + relax.op.sqrt( + relax.op.sum(relax.op.multiply(data, data), axis=axis, keepdims=keepdims) + ) ) else: reci_order = relax.const(1 / order, dtype=dtype) order = relax.const(order, dtype=dtype) return self.block_builder.emit( relax.op.power( - relax.op.sum(relax.op.power(relax.op.abs(data), order), axis=axis, keepdims=keepdims), + relax.op.sum( + relax.op.power(relax.op.abs(data), order), axis=axis, keepdims=keepdims + ), reci_order, ) ) From 9d326cc6f9bbce70fa86f2f8293ae0c1002e6c59 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Fri, 18 Apr 2025 22:49:29 +0800 Subject: [PATCH 09/20] Update exported_program_translator.py --- .../frontend/torch/exported_program_translator.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index beb8d32d1c25..8f6b7169aeec 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -210,15 +210,21 @@ def _norm(self, node: fx.Node) -> relax.Var: keepdims = node.args[3] if len(node.args) > 3 else False if order == float("inf"): - return self.block_builder.emit(relax.op.max(relax.op.abs(data), axis=axis, keepdims=keepdims)) + return self.block_builder.emit( + relax.op.max(relax.op.abs(data), axis=axis, keepdims=keepdims) + ) elif order == float("-inf"): - return self.block_builder.emit(relax.op.min(relax.op.abs(data), axis=axis, keepdims=keepdims)) + return self.block_builder.emit( + relax.op.min(relax.op.abs(data), axis=axis, keepdims=keepdims) + ) else: reci_order = relax.const(1 / order, dtype=dtype) order = relax.const(order, dtype=dtype) return self.block_builder.emit( relax.op.power( - relax.op.sum(relax.op.power(relax.op.abs(data), order), axis=axis, keepdims=keepdims), + relax.op.sum( + relax.op.power(relax.op.abs(data), order), axis=axis, keepdims=keepdims + ), reci_order, ) ) From 3babec099ccee244c5cec8c61608f7e8e3e9874b Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Fri, 18 Apr 2025 22:52:52 +0800 Subject: [PATCH 10/20] Update test_frontend_from_exported_program.py --- .../relax/test_frontend_from_exported_program.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 1387e00c6b70..8658b3cefde9 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4378,7 +4378,6 @@ def main( def test_norm(): - class Norm(Module): def __init__(self, p, dim=None, keepdim=False): super().__init__() @@ -4453,7 +4452,9 @@ def main( lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0) lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(-4.0, "float32")) lv2: R.Tensor((1, 1, 1, 1), dtype="float32") = R.sum(lv1, axis=None, keepdims=True) - lv3: R.Tensor((1, 1, 1, 1), dtype="float32") = R.power(lv2, R.const(-0.25, "float32")) + lv3: R.Tensor((1, 1, 1, 1), dtype="float32") = R.power( + lv2, 
R.const(-0.25, "float32") + ) gv: R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")) = (lv3,) R.output(gv) return gv @@ -4473,10 +4474,9 @@ def main( R.output(gv) return gv - norms = [ - ((float('inf'), None, False), Expected1), - ((float('-inf'), None, False), Expected2), + ((float("inf"), None, False), Expected1), + ((float("-inf"), None, False), Expected2), ((float(2), None, False), Expected3), ((float(1.0), None, False), Expected4), ((float(-4), None, True), Expected5), From c994140065efc286ddefea79cd8acd2fa30ffa93 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Fri, 18 Apr 2025 22:54:18 +0800 Subject: [PATCH 11/20] Update test_frontend_from_fx.py --- tests/python/relax/test_frontend_from_fx.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 83ab4a0a32dd..ebb7ceb68cf9 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -4910,13 +4910,13 @@ def main( return gv norms = [ - ((float('inf'), None, False), Expected1), - ((float('-inf'), None, False), Expected2), + ((float("inf"), None, False), Expected1), + ((float("-inf"), None, False), Expected2), ((float(2), None, False), Expected3), ((float(1.0), None, False), Expected4), ((float(-4), None, True), Expected5), ((float(0.5), None, True), Expected6), - (("fro", None, False), Expected7) + (("fro", None, False), Expected7), ] for (p, dim, keepdim), expected in norms: From 74eeb17df175de5b1c6774d9cff64b16e5705e1c Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sat, 19 Apr 2025 16:23:42 +0800 Subject: [PATCH 12/20] Update fx_translator.py --- python/tvm/relax/frontend/torch/fx_translator.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index fd2bc62f5380..38e18decc8b5 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -447,13 +447,15 @@ def _norm(self, node: fx.Node) -> relax.Var: ) ) else: - reci_order = relax.const(1 / order, dtype=dtype) - order = relax.const(order, dtype=dtype) + ord_expr = (order if isinstance(order, relax.Expr) else relax.const(float(order), dtype=dtype) ) + reci_order = ( + relax.op.divide(relax.const(1.0, dtype), ord_expr) + if isinstance(order, relax.Expr) + else relax.const(1.0 / order, dtype=dtype) + ) return self.block_builder.emit( relax.op.power( - relax.op.sum( - relax.op.power(relax.op.abs(data), order), axis=axis, keepdims=keepdims - ), + relax.op.sum(relax.op.power(relax.op.abs(data), ord_expr), axis=axis, keepdims=keepdims), reci_order, ) ) From c47dce193d50916c9c9cc86b5b2b8c88af2d8538 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sat, 19 Apr 2025 16:25:59 +0800 Subject: [PATCH 13/20] Update exported_program_translator.py --- .../torch/exported_program_translator.py | 30 ------------------- 1 file changed, 30 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 8f6b7169aeec..932607287571 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -200,35 +200,6 @@ def _upsample_nearest2d(self, node: fx.node) -> relax.Var: align_corners=align_corners, ) - ########## Statistical 
########## - - def _norm(self, node: fx.Node) -> relax.Var: - data = self.env[node.args[0]] - dtype = data.struct_info.dtype - order = node.args[1] if len(node.args) > 1 else node.kwargs.get("p", 2) - axis = node.args[2] if len(node.args) > 2 else None - keepdims = node.args[3] if len(node.args) > 3 else False - - if order == float("inf"): - return self.block_builder.emit( - relax.op.max(relax.op.abs(data), axis=axis, keepdims=keepdims) - ) - elif order == float("-inf"): - return self.block_builder.emit( - relax.op.min(relax.op.abs(data), axis=axis, keepdims=keepdims) - ) - else: - reci_order = relax.const(1 / order, dtype=dtype) - order = relax.const(order, dtype=dtype) - return self.block_builder.emit( - relax.op.power( - relax.op.sum( - relax.op.power(relax.op.abs(data), order), axis=axis, keepdims=keepdims - ), - reci_order, - ) - ) - ########## Manipulation ########## def _narrow(self, node: fx.Node) -> relax.Var: @@ -425,7 +396,6 @@ def create_convert_map( "upsample_nearest2d.vec": self._upsample_nearest2d, # statistical "mean.dim": self._mean, - "linalg_vector_norm.default": self._norm, "prod.default": self._prod, "std.correction": self._std, "sum.default": self._sum, From 687e60aaab08478d348e138998bcb78b5f41707a Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sat, 19 Apr 2025 16:35:50 +0800 Subject: [PATCH 14/20] Update base_fx_graph_translator.py --- .../torch/base_fx_graph_translator.py | 36 ------------------- 1 file changed, 36 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 237b971c7cc9..57a4649caf4a 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -417,42 +417,6 @@ def _rsub(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.subtract(rhs, lhs)) - ########## Linear Algebra ########## - - def _linalg_vector_norm(self, node: fx.Node) -> relax.Var: - - args = self.retrieve_args(node) - - data = args[0] - # Default ord=2 if not supplied - ord_val = args[1] if len(args) > 1 else 2.0 - dim = args[2] if len(args) > 2 else None - keepdim = args[3] if len(args) > 3 else False - - # If ord_val is a Python float/int, wrap it in a Relax const - # so that it matches data's dtype. 
- dtype = data.struct_info.dtype - ord_expr = ( - ord_val if isinstance(ord_val, relax.Expr) else relax.const(float(ord_val), dtype) - ) - # Reciprocal - reci_expr = ( - relax.op.divide(relax.const(1.0, dtype), ord_expr) - if isinstance(ord_val, relax.Expr) - else relax.const(1.0 / float(ord_val), dtype) - ) - - # abs(data) - abs_data = self.block_builder.emit(relax.op.abs(data)) - # abs_data^ord - abs_data_pow = self.block_builder.emit(relax.op.power(abs_data, ord_expr)) - # sum over dim - reduced = self.block_builder.emit(relax.op.sum(abs_data_pow, dim, keepdims=keepdim)) - # (sum(...))^(1/ord) - norm_val = self.block_builder.emit(relax.op.power(reduced, reci_expr)) - - return norm_val - ########## Neural Network ########## def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: From 3ce1a3a58e4dd09e1f35467e4ba3befee14823e6 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sat, 19 Apr 2025 16:37:52 +0800 Subject: [PATCH 15/20] Update exported_program_translator.py --- .../torch/exported_program_translator.py | 29 ++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 932607287571..6653365ef5e6 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -64,6 +64,33 @@ def _reciprocal(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] return self.block_builder.emit(relax.op.divide(relax.const(1.0, x.struct_info.dtype), x)) + ########## Linear Algebra ########## + + def _linalg_vector_norm(self, node: fx.Node) -> relax.Var: + data = self.env[node.args[0]] + dtype = data.struct_info.dtype + order = node.args[1] if len(node.args) > 1 else node.kwargs.get("p", 2) + axis = node.args[2] if len(node.args) > 2 else None + keepdims = node.args[3] if len(node.args) > 3 else False + + if order == float("inf"): + return self.block_builder.emit(relax.op.max(relax.op.abs(data), axis=axis, keepdims=keepdims)) + elif order == float("-inf"): + return self.block_builder.emit(relax.op.min(relax.op.abs(data), axis=axis, keepdims=keepdims)) + else: + ord_expr = (order if isinstance(order, relax.Expr) else relax.const(float(order), dtype=dtype) ) + reci_order = ( + relax.op.divide(relax.const(1.0, dtype), ord_expr) + if isinstance(order, relax.Expr) + else relax.const(1.0 / order, dtype=dtype) + ) + return self.block_builder.emit( + relax.op.power( + relax.op.sum(relax.op.power(relax.op.abs(data), ord_expr), axis=axis, keepdims=keepdims), + reci_order, + ) + ) + ########## Neural Network ########## def _batch_norm(self, node: fx.Node, training) -> relax.Var: @@ -199,7 +226,7 @@ def _upsample_nearest2d(self, node: fx.node) -> relax.Var: method="nearest_neighbor", align_corners=align_corners, ) - + ########## Manipulation ########## def _narrow(self, node: fx.Node) -> relax.Var: From d260e38eb75719da809a4fcb3af03aaf4e65e73f Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sat, 19 Apr 2025 17:59:45 +0800 Subject: [PATCH 16/20] Update exported_program_translator.py --- .../torch/exported_program_translator.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 6653365ef5e6..5fce237a7b76 100644 --- 
a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -74,11 +74,17 @@ def _linalg_vector_norm(self, node: fx.Node) -> relax.Var: keepdims = node.args[3] if len(node.args) > 3 else False if order == float("inf"): - return self.block_builder.emit(relax.op.max(relax.op.abs(data), axis=axis, keepdims=keepdims)) + return self.block_builder.emit( + relax.op.max(relax.op.abs(data), axis=axis, keepdims=keepdims) + ) elif order == float("-inf"): - return self.block_builder.emit(relax.op.min(relax.op.abs(data), axis=axis, keepdims=keepdims)) + return self.block_builder.emit( + relax.op.min(relax.op.abs(data), axis=axis, keepdims=keepdims) + ) else: - ord_expr = (order if isinstance(order, relax.Expr) else relax.const(float(order), dtype=dtype) ) + ord_expr = ( + order if isinstance(order, relax.Expr) else relax.const(float(order), dtype=dtype) + ) reci_order = ( relax.op.divide(relax.const(1.0, dtype), ord_expr) if isinstance(order, relax.Expr) @@ -86,7 +92,9 @@ def _linalg_vector_norm(self, node: fx.Node) -> relax.Var: ) return self.block_builder.emit( relax.op.power( - relax.op.sum(relax.op.power(relax.op.abs(data), ord_expr), axis=axis, keepdims=keepdims), + relax.op.sum( + relax.op.power(relax.op.abs(data), ord_expr), axis=axis, keepdims=keepdims + ), reci_order, ) ) @@ -226,7 +234,7 @@ def _upsample_nearest2d(self, node: fx.node) -> relax.Var: method="nearest_neighbor", align_corners=align_corners, ) - + ########## Manipulation ########## def _narrow(self, node: fx.Node) -> relax.Var: From e0f6eeaf4ef9a0d7eb59262f5ac95a037cc6fde5 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sat, 19 Apr 2025 18:01:14 +0800 Subject: [PATCH 17/20] Update fx_translator.py --- python/tvm/relax/frontend/torch/fx_translator.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 38e18decc8b5..6d7d1a92518b 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -447,7 +447,9 @@ def _norm(self, node: fx.Node) -> relax.Var: ) ) else: - ord_expr = (order if isinstance(order, relax.Expr) else relax.const(float(order), dtype=dtype) ) + ord_expr = ( + order if isinstance(order, relax.Expr) else relax.const(float(order), dtype=dtype) + ) reci_order = ( relax.op.divide(relax.const(1.0, dtype), ord_expr) if isinstance(order, relax.Expr) @@ -455,7 +457,9 @@ def _norm(self, node: fx.Node) -> relax.Var: ) return self.block_builder.emit( relax.op.power( - relax.op.sum(relax.op.power(relax.op.abs(data), ord_expr), axis=axis, keepdims=keepdims), + relax.op.sum( + relax.op.power(relax.op.abs(data), ord_expr), axis=axis, keepdims=keepdims + ), reci_order, ) ) From b4106fa94026d736301a0f4af8d0c2d81ac3d15d Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Mon, 21 Apr 2025 23:57:14 +0800 Subject: [PATCH 18/20] Update fx_translator.py --- .../tvm/relax/frontend/torch/fx_translator.py | 42 ------------------- 1 file changed, 42 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index df76309f57b0..548320bd854e 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -422,48 +422,6 @@ def _lerp(self, node: fx.Node) -> relax.Var: relax.op.add(start, relax.op.multiply(weight, relax.op.subtract(end, 
start))) ) - ########## Statistical ########## - - def _norm(self, node: fx.Node) -> relax.Var: - data = self.env[node.args[0]] - dtype = data.struct_info.dtype - order = node.args[1] if len(node.args) > 1 else node.kwargs.get("p", 2) - axis = node.args[2] if len(node.args) > 2 else None - keepdims = node.args[3] if len(node.args) > 3 else False - - if order == float("inf"): - return self.block_builder.emit( - relax.op.max(relax.op.abs(data), axis=axis, keepdims=keepdims) - ) - elif order == float("-inf"): - return self.block_builder.emit( - relax.op.min(relax.op.abs(data), axis=axis, keepdims=keepdims) - ) - # frobenius_norm - elif order == "fro": - return self.block_builder.emit( - relax.op.sqrt( - relax.op.sum(relax.op.multiply(data, data), axis=axis, keepdims=keepdims) - ) - ) - else: - ord_expr = ( - order if isinstance(order, relax.Expr) else relax.const(float(order), dtype=dtype) - ) - reci_order = ( - relax.op.divide(relax.const(1.0, dtype), ord_expr) - if isinstance(order, relax.Expr) - else relax.const(1.0 / order, dtype=dtype) - ) - return self.block_builder.emit( - relax.op.power( - relax.op.sum( - relax.op.power(relax.op.abs(data), ord_expr), axis=axis, keepdims=keepdims - ), - reci_order, - ) - ) - ########## Manipulation ########## def _chunk(self, node: fx.Node) -> relax.Var: From 181d64c4b606d5f2c84f336498d2911f8cb06b98 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Mon, 21 Apr 2025 23:58:01 +0800 Subject: [PATCH 19/20] Update base_fx_graph_translator.py --- .../torch/base_fx_graph_translator.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index f91c361829d5..b0bdc598f4e4 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -925,6 +925,46 @@ def _mean(self, node: fx.Node) -> relax.Var: keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim)) + def _norm(self, node: fx.Node) -> relax.Var: + data = self.env[node.args[0]] + dtype = data.struct_info.dtype + order = node.args[1] if len(node.args) > 1 else node.kwargs.get("p", 2) + axis = node.args[2] if len(node.args) > 2 else None + keepdims = node.args[3] if len(node.args) > 3 else False + + if order == float("inf"): + return self.block_builder.emit( + relax.op.max(relax.op.abs(data), axis=axis, keepdims=keepdims) + ) + elif order == float("-inf"): + return self.block_builder.emit( + relax.op.min(relax.op.abs(data), axis=axis, keepdims=keepdims) + ) + # frobenius_norm + elif order == "fro": + return self.block_builder.emit( + relax.op.sqrt( + relax.op.sum(relax.op.multiply(data, data), axis=axis, keepdims=keepdims) + ) + ) + else: + ord_expr = ( + order if isinstance(order, relax.Expr) else relax.const(float(order), dtype=dtype) + ) + reci_order = ( + relax.op.divide(relax.const(1.0, dtype), ord_expr) + if isinstance(order, relax.Expr) + else relax.const(1.0 / order, dtype=dtype) + ) + return self.block_builder.emit( + relax.op.power( + relax.op.sum( + relax.op.power(relax.op.abs(data), ord_expr), axis=axis, keepdims=keepdims + ), + reci_order, + ) + ) + def _prod(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] From 70ba7965d3d760852ed799d37b88897ab1a8f1f4 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Mon, 21 Apr 2025 23:59:20 
+0800 Subject: [PATCH 20/20] Update exported_program_translator.py --- .../torch/exported_program_translator.py | 37 +------------------ 1 file changed, 1 insertion(+), 36 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 1596325b5a66..a178982acd06 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -64,41 +64,6 @@ def _reciprocal(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] return self.block_builder.emit(relax.op.divide(relax.const(1.0, x.struct_info.dtype), x)) - ########## Linear Algebra ########## - - def _linalg_vector_norm(self, node: fx.Node) -> relax.Var: - data = self.env[node.args[0]] - dtype = data.struct_info.dtype - order = node.args[1] if len(node.args) > 1 else node.kwargs.get("p", 2) - axis = node.args[2] if len(node.args) > 2 else None - keepdims = node.args[3] if len(node.args) > 3 else False - - if order == float("inf"): - return self.block_builder.emit( - relax.op.max(relax.op.abs(data), axis=axis, keepdims=keepdims) - ) - elif order == float("-inf"): - return self.block_builder.emit( - relax.op.min(relax.op.abs(data), axis=axis, keepdims=keepdims) - ) - else: - ord_expr = ( - order if isinstance(order, relax.Expr) else relax.const(float(order), dtype=dtype) - ) - reci_order = ( - relax.op.divide(relax.const(1.0, dtype), ord_expr) - if isinstance(order, relax.Expr) - else relax.const(1.0 / order, dtype=dtype) - ) - return self.block_builder.emit( - relax.op.power( - relax.op.sum( - relax.op.power(relax.op.abs(data), ord_expr), axis=axis, keepdims=keepdims - ), - reci_order, - ) - ) - ########## Neural Network ########## def _batch_norm(self, node: fx.Node, training) -> relax.Var: @@ -404,7 +369,7 @@ def create_convert_map( "__xor__.Tensor": self._binary_op(relax.op.bitwise_xor, operator.xor), "__xor__.Scalar": self._binary_op(relax.op.bitwise_xor, operator.xor), # linear algebra - "linalg_vector_norm.default": self._linalg_vector_norm, + "linalg_vector_norm.default": self._norm, # neural network "_native_batch_norm_legit_functional.default": self._batch_norm_legit_functional, "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training,
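
For reference, a minimal end-to-end sketch of how the mapping this series settles on can be exercised, written against the final state of the branch (base `_norm` helper shared by both frontends, with `linalg_vector_norm.default` routed to it). This is an illustration, not part of the patches: the `VectorNorm` module name is hypothetical, and it assumes the `torch.export` flow and the `from_exported_program` entry point already used by the tests above.

    # Sketch only: exercises the linalg_vector_norm.default mapping through
    # the exported-program frontend, assuming the final state of this series.
    import torch
    from torch.export import export

    from tvm.relax.frontend.torch import from_exported_program


    class VectorNorm(torch.nn.Module):  # hypothetical module, mirrors the tests' Norm
        def forward(self, x):
            # With a float p, torch.norm decomposes to linalg_vector_norm.default,
            # which _norm lowers to power(sum(power(abs(x), p)), 1/p).
            return torch.norm(x, p=2.0)


    example_args = (torch.randn(1, 3, 5, 3, dtype=torch.float32),)
    exported = export(VectorNorm(), example_args)
    mod = from_exported_program(exported)
    mod.show()  # expect an R.abs / R.power / R.sum / R.power chain, as in Expected3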