From 0da2cb72f0d4a469288113599ea55990178ee0cc Mon Sep 17 00:00:00 2001
From: deivanayakisankaralingam
Date: Fri, 28 Mar 2025 08:59:26 +0000
Subject: [PATCH 1/4] Add softshrink op support to the exported program
 translator, with tests

---
 .../torch/base_fx_graph_translator.py         | 33 ++++++++++++++
 .../torch/exported_program_translator.py      |  1 +
 .../test_frontend_from_exported_program.py    | 43 +++++++++++++++++++
 3 files changed, 77 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 71554a8a5bab..e3d4a79f2265 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -307,6 +307,39 @@ def _softmax(self, node: fx.Node) -> relax.Var:
         dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
         return self.block_builder.emit(relax.op.nn.softmax(x, dim))
 
+    def _softshrink(self, node: fx.Node) -> relax.Var:
+        """
+        Applies the Softshrink activation function in Relax.
+
+        Softshrink(x) =
+            x - λ if x > λ
+            x + λ if x < -λ
+            0 otherwise
+
+        Args:
+            node (fx.Node): The input node containing the tensor and lambda value.
+
+        Returns:
+            relax.Var: The resulting tensor after applying Softshrink.
+        """
+        args = self.retrieve_args(node)
+        x = args[0]
+        lambd = relax.const(args[1] if len(args) > 1 else 0.5, x.struct_info.dtype)
+
+        # Apply Softshrink transformation with masking
+        shrink_pos = relax.op.multiply(
+            relax.op.subtract(x, lambd), 
+            relax.op.astype(relax.op.greater(x, lambd), x.struct_info.dtype)
+        )
+
+        shrink_neg = relax.op.multiply(
+            relax.op.add(x, lambd),
+            relax.op.astype(relax.op.less(x, relax.op.negative(lambd)), x.struct_info.dtype)
+        )
+
+        # Combine the positive and negative shrink results
+        return self.block_builder.emit(relax.op.add(shrink_pos, shrink_neg))
+
     def _selu(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("alpha", 1.6732631921768188)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 0f1dc11787da..a28da6ee72e9 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -282,6 +282,7 @@ def create_convert_map(
             "sin.default": self._unary_op(relax.op.sin),
             "sinh.default": self._unary_op(relax.op.sinh),
             "softmax.int": self._softmax,
+            "softshrink.default": self._softshrink,
             "sqrt.default": self._unary_op(relax.op.sqrt),
             "square.default": self._unary_op(relax.op.square),
             "tan.default": self._unary_op(relax.op.tan),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 739fe87dc931..66b635696333 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -607,6 +607,9 @@
     # softmax
     test_softmax()
 
+    # softshrink
+    test_softshrink()
+
     # tril, triu
     test_tril_triu()
 
@@ -740,6 +743,46 @@ def main(
     verify_model(Softmax(), example_args, {}, expected1)
     verify_model(Softmax2(), example_args, {}, expected1)
 
+def test_softshrink():
+    class Softshrink(Module):
+        def __init__(self):
+            super().__init__()
+            self.softshrink = torch.nn.Softshrink(lambd=0.5)
+
+        def forward(self, input):
+            return self.softshrink(input)
+
+    class Softshrink2(Module):
+        def forward(self, input):
+            return torch.nn.functional.softshrink(input, lambd=0.5)
+
+    @tvm.script.ir_module
+    class expected_softshrink:
+        @R.function
+        def main(
+            input: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(input, R.const(0.5, "float32"))
+                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(input, R.const(0.5, "float32"))
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv1, "float32")
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv, lv2)
+
+                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(input, R.const(0.5, "float32"))
+                lv5: R.Tensor((), dtype="float32") = R.negative(R.const(0.5, "float32"))
+                lv6: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(input, lv5)
+                lv7: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv6, "float32")
+                lv8: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv4, lv7)
+
+                lv9: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv3, lv8)
+
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv9,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Softshrink(), example_args, {}, expected_softshrink)
+    verify_model(Softshrink2(), example_args, {}, expected_softshrink)
 
 def test_tril_triu():
     example_args = (torch.randn(10, 10, dtype=torch.float32),)
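The masking decomposition that _softshrink emits above is easy to sanity-check against PyTorch itself. A minimal sketch, independent of TVM; the softshrink_by_masking helper is illustrative only and not part of the patch:

    import torch

    def softshrink_by_masking(x: torch.Tensor, lambd: float = 0.5) -> torch.Tensor:
        # (x - lambd) where x > lambd, (x + lambd) where x < -lambd, 0 elsewhere,
        # mirroring the relax.op.greater / relax.op.less masks used in _softshrink.
        pos = (x - lambd) * (x > lambd).to(x.dtype)
        neg = (x + lambd) * (x < -lambd).to(x.dtype)
        return pos + neg

    x = torch.randn(1, 3, 10, 10)
    assert torch.allclose(
        softshrink_by_masking(x), torch.nn.functional.softshrink(x, lambd=0.5)
    )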
From 2e2c9b1698b67ce83f9fd315bfb93f1f49fe4a25 Mon Sep 17 00:00:00 2001
From: deivanayakisankaralingam
Date: Mon, 31 Mar 2025 05:47:09 +0000
Subject: [PATCH 2/4] fix lint issue

---
 .../torch/base_fx_graph_translator.py         | 12 ++++----
 .../test_frontend_from_exported_program.py    | 28 +++++++++++++------
 2 files changed, 26 insertions(+), 14 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index e3d4a79f2265..5c8c09cc4e8c 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -311,7 +311,7 @@ def _softshrink(self, node: fx.Node) -> relax.Var:
         """
         Applies the Softshrink activation function in Relax.
 
-        Softshrink(x) = 
+        Softshrink(x) =
             x - λ if x > λ
             x + λ if x < -λ
             0 otherwise
@@ -328,13 +328,15 @@ def _softshrink(self, node: fx.Node) -> relax.Var:
 
         # Apply Softshrink transformation with masking
         shrink_pos = relax.op.multiply(
-            relax.op.subtract(x, lambd), 
-            relax.op.astype(relax.op.greater(x, lambd), x.struct_info.dtype)
+            relax.op.subtract(x, lambd),
+            relax.op.astype(relax.op.greater(x, lambd), x.struct_info.dtype),
         )
 
         shrink_neg = relax.op.multiply(
-            relax.op.add(x, lambd),
-            relax.op.astype(relax.op.less(x, relax.op.negative(lambd)), x.struct_info.dtype)
+            relax.op.add(x, lambd),
+            relax.op.astype(
+                relax.op.less(x, relax.op.negative(lambd)), x.struct_info.dtype
+            ),
         )
 
         # Combine the positive and negative shrink results
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 66b635696333..d5f78be12e9c 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -751,27 +751,37 @@ def __init__(self):
 
         def forward(self, input):
             return self.softshrink(input)
-    
+
     class Softshrink2(Module):
         def forward(self, input):
            return torch.nn.functional.softshrink(input, lambd=0.5)
-    
+
     @tvm.script.ir_module
     class expected_softshrink:
         @R.function
         def main(
-            input: R.Tensor((1, 3, 10, 10), dtype="float32")
+            input: R.Tensor((1, 3, 10, 10), dtype="float32"),
         ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(input, R.const(0.5, "float32"))
-                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(input, R.const(0.5, "float32"))
-                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv1, "float32")
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
+                    input, R.const(0.5, "float32")
+                )
+                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(
+                    input, R.const(0.5, "float32")
+                )
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(
+                    lv1, "float32"
+                )
                 lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv, lv2)
 
-                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(input, R.const(0.5, "float32"))
+                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(
+                    input, R.const(0.5, "float32")
+                )
                 lv5: R.Tensor((), dtype="float32") = R.negative(R.const(0.5, "float32"))
                 lv6: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(input, lv5)
-                lv7: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv6, "float32")
+                lv7: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(
+                    lv6, "float32"
+                )
                 lv8: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv4, lv7)
 
                 lv9: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv3, lv8)
@@ -778,7 +788,7 @@ def main(
                 gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv9,)
                 R.output(gv)
             return gv
-    
+
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
     verify_model(Softshrink(), example_args, {}, expected_softshrink)
     verify_model(Softshrink2(), example_args, {}, expected_softshrink)
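With softshrink.default registered in create_convert_map, an exported program that calls the op should now translate end to end. A minimal usage sketch, assuming a TVM build that includes this series:

    import torch
    from tvm.relax.frontend.torch import from_exported_program

    class Softshrink(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.softshrink(x, lambd=0.5)

    # torch.export lowers the call to aten softshrink.default, which the
    # convert map now routes to _softshrink.
    exported = torch.export.export(Softshrink(), (torch.randn(1, 3, 10, 10),))
    mod = from_exported_program(exported)
    mod.show()  # the printed Relax should match expected_softshrink above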
From d5734e984c3b63ca4ed95e9dc9c798d055fe7587 Mon Sep 17 00:00:00 2001
From: deivanayakisankaralingam
Date: Mon, 31 Mar 2025 06:04:34 +0000
Subject: [PATCH 3/4] update the formatting to fix lint issues

---
 python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 4 +---
 tests/python/relax/test_frontend_from_exported_program.py   | 8 ++------
 2 files changed, 3 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 5c8c09cc4e8c..466f88ef1867 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -334,9 +334,7 @@ def _softshrink(self, node: fx.Node) -> relax.Var:
 
         shrink_neg = relax.op.multiply(
             relax.op.add(x, lambd),
-            relax.op.astype(
-                relax.op.less(x, relax.op.negative(lambd)), x.struct_info.dtype
-            ),
+            relax.op.astype(relax.op.less(x, relax.op.negative(lambd)), x.struct_info.dtype),
         )
 
         # Combine the positive and negative shrink results
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index d5f78be12e9c..2fe89a8a2b76 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -769,9 +769,7 @@ def main(
                 lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(
                     input, R.const(0.5, "float32")
                 )
-                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(
-                    lv1, "float32"
-                )
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv1, "float32")
                 lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv, lv2)
 
                 lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(
@@ -779,9 +777,7 @@ def main(
                 )
                 lv5: R.Tensor((), dtype="float32") = R.negative(R.const(0.5, "float32"))
                 lv6: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(input, lv5)
-                lv7: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(
-                    lv6, "float32"
-                )
+                lv7: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv6, "float32")
                 lv8: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv4, lv7)
 
                 lv9: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv3, lv8)
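Patches 2 and 3 are formatter churn: lines split to satisfy the length limit, then partially re-joined. A hedged way to check the result locally, assuming black is the formatter behind TVM's Python lint (run from the repository root):

    import subprocess

    # Re-run the formatter in check mode over the two files this series touches;
    # a zero exit status means no further lint-only follow-ups are needed.
    subprocess.run(
        [
            "black",
            "--check",
            "python/tvm/relax/frontend/torch/base_fx_graph_translator.py",
            "tests/python/relax/test_frontend_from_exported_program.py",
        ],
        check=True,
    )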
From 18981fe58672cd3abbc41bbd1eea36e563a25f43 Mon Sep 17 00:00:00 2001
From: deivanayakisankaralingam
Date: Mon, 31 Mar 2025 06:28:22 +0000
Subject: [PATCH 4/4] modify the code format to fix lint issue

---
 python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 2 +-
 tests/python/relax/test_frontend_from_exported_program.py   | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 466f88ef1867..fee8cf9de9a6 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -339,7 +339,7 @@ def _softshrink(self, node: fx.Node) -> relax.Var:
 
         # Combine the positive and negative shrink results
         return self.block_builder.emit(relax.op.add(shrink_pos, shrink_neg))
-    
+
     def _selu(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("alpha", 1.6732631921768188)
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 2fe89a8a2b76..98f0f1d9cac6 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -743,6 +743,7 @@ def main(
     verify_model(Softmax(), example_args, {}, expected1)
     verify_model(Softmax2(), example_args, {}, expected1)
 
+
 def test_softshrink():
     class Softshrink(Module):
         def __init__(self):
@@ -790,6 +791,7 @@ def main(
     verify_model(Softshrink(), example_args, {}, expected_softshrink)
     verify_model(Softshrink2(), example_args, {}, expected_softshrink)
 
+
 def test_tril_triu():
     example_args = (torch.randn(10, 10, dtype=torch.float32),)
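To exercise only the new test from this series, a minimal invocation sketch, assuming pytest is available in the environment:

    import pytest

    # -k selects test_softshrink (and nothing else) from the frontend test file.
    pytest.main(
        [
            "tests/python/relax/test_frontend_from_exported_program.py",
            "-k",
            "softshrink",
        ]
    )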