diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 71554a8a5bab..fee8cf9de9a6 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -307,6 +307,39 @@ def _softmax(self, node: fx.Node) -> relax.Var:
         dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
         return self.block_builder.emit(relax.op.nn.softmax(x, dim))

+    def _softshrink(self, node: fx.Node) -> relax.Var:
+        """
+        Applies the Softshrink activation function in Relax.
+
+        Softshrink(x) =
+            x - λ   if x > λ
+            x + λ   if x < -λ
+            0       otherwise
+
+        Args:
+            node (fx.Node): The input node containing the tensor and lambda value.
+
+        Returns:
+            relax.Var: The resulting tensor after applying Softshrink.
+        """
+        args = self.retrieve_args(node)
+        x = args[0]
+        lambd = relax.const(args[1] if len(args) > 1 else 0.5, x.struct_info.dtype)
+
+        # Apply the Softshrink transformation with masking: keep (x - λ) where x > λ
+        shrink_pos = relax.op.multiply(
+            relax.op.subtract(x, lambd),
+            relax.op.astype(relax.op.greater(x, lambd), x.struct_info.dtype),
+        )
+
+        shrink_neg = relax.op.multiply(
+            relax.op.add(x, lambd),
+            relax.op.astype(relax.op.less(x, relax.op.negative(lambd)), x.struct_info.dtype),
+        )
+
+        # Combine the positive and negative shrink results
+        return self.block_builder.emit(relax.op.add(shrink_pos, shrink_neg))
+
     def _selu(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("alpha", 1.6732631921768188)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 0f1dc11787da..a28da6ee72e9 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -282,6 +282,7 @@ def create_convert_map(
             "sin.default": self._unary_op(relax.op.sin),
             "sinh.default": self._unary_op(relax.op.sinh),
             "softmax.int": self._softmax,
+            "softshrink.default": self._softshrink,
             "sqrt.default": self._unary_op(relax.op.sqrt),
             "square.default": self._unary_op(relax.op.square),
             "tan.default": self._unary_op(relax.op.tan),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 739fe87dc931..98f0f1d9cac6 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -607,6 +607,9 @@ def main(
     # softmax
     test_softmax()

+    # softshrink
+    test_softshrink()
+
     # tril, triu
     test_tril_triu()

@@ -741,6 +744,54 @@ def main(
     verify_model(Softmax2(), example_args, {}, expected1)


+def test_softshrink():
+    class Softshrink(Module):
+        def __init__(self):
+            super().__init__()
+            self.softshrink = torch.nn.Softshrink(lambd=0.5)
+
+        def forward(self, input):
+            return self.softshrink(input)
+
+    class Softshrink2(Module):
+        def forward(self, input):
+            return torch.nn.functional.softshrink(input, lambd=0.5)
+
+    @tvm.script.ir_module
+    class expected_softshrink:
+        @R.function
+        def main(
+            input: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
+                    input, R.const(0.5, "float32")
+                )
+                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(
+                    input, R.const(0.5, "float32")
+                )
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv1, "float32")
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv, lv2)
+
+                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(
+                    input, R.const(0.5, "float32")
+                )
+                lv5: R.Tensor((), dtype="float32") = R.negative(R.const(0.5, "float32"))
+                lv6: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(input, lv5)
+                lv7: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv6, "float32")
+                lv8: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv4, lv7)
+
+                lv9: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv3, lv8)
+
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv9,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Softshrink(), example_args, {}, expected_softshrink)
+    verify_model(Softshrink2(), example_args, {}, expected_softshrink)
+
+
 def test_tril_triu():
     example_args = (torch.randn(10, 10, dtype=torch.float32),)
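Note (illustration only, not part of the patch): the mask-based decomposition that _softshrink lowers to Relax ops can be sanity-checked against PyTorch's own reference implementation. The sketch below is self-contained; the tensor shape and lambd value simply mirror the test above.

# Standalone sketch: verify that
#   (x - lambd) * (x > lambd) + (x + lambd) * (x < -lambd)
# reproduces torch.nn.functional.softshrink, i.e. the same decomposition
# emitted by _softshrink (subtract/greater and add/less branches, summed).
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 10, 10, dtype=torch.float32)
lambd = 0.5

decomposed = (x - lambd) * (x > lambd).to(x.dtype) + (x + lambd) * (x < -lambd).to(x.dtype)
assert torch.allclose(decomposed, F.softshrink(x, lambd))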