Merged
33 changes: 33 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -307,6 +307,39 @@ def _softmax(self, node: fx.Node) -> relax.Var:
        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
        return self.block_builder.emit(relax.op.nn.softmax(x, dim))

    def _softshrink(self, node: fx.Node) -> relax.Var:
        """
        Applies the Softshrink activation function in Relax.

        Softshrink(x) =
            x - λ   if x > λ
            x + λ   if x < -λ
            0       otherwise

        Args:
            node (fx.Node): The input node containing the tensor and the lambda value.

        Returns:
            relax.Var: The resulting tensor after applying Softshrink.
        """
        args = self.retrieve_args(node)
        x = args[0]
        lambd = relax.const(args[1] if len(args) > 1 else 0.5, x.struct_info.dtype)

        # Keep (x - lambd) only where x > lambd; the boolean mask is cast to the input dtype.
        shrink_pos = relax.op.multiply(
            relax.op.subtract(x, lambd),
            relax.op.astype(relax.op.greater(x, lambd), x.struct_info.dtype),
        )

        # Keep (x + lambd) only where x < -lambd.
        shrink_neg = relax.op.multiply(
            relax.op.add(x, lambd),
            relax.op.astype(relax.op.less(x, relax.op.negative(lambd)), x.struct_info.dtype),
        )

        # The two masked branches are disjoint, so their sum is the full Softshrink result.
        return self.block_builder.emit(relax.op.add(shrink_pos, shrink_neg))

    def _selu(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
        alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("alpha", 1.6732631921768188)
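As a sanity check on the lowering above, the same mask-based decomposition can be reproduced directly in PyTorch and compared against torch.nn.functional.softshrink. This is a minimal illustrative sketch, not part of the PR; the tensor shape and lambd value are arbitrary.

import torch

x = torch.randn(1, 3, 10, 10)
lambd = 0.5

# Same decomposition as the Relax lowering: keep (x - lambd) where x > lambd,
# keep (x + lambd) where x < -lambd, and zero everywhere else.
shrink_pos = (x - lambd) * (x > lambd).to(x.dtype)
shrink_neg = (x + lambd) * (x < -lambd).to(x.dtype)

torch.testing.assert_close(shrink_pos + shrink_neg, torch.nn.functional.softshrink(x, lambd))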
@@ -282,6 +282,7 @@ def create_convert_map(
            "sin.default": self._unary_op(relax.op.sin),
            "sinh.default": self._unary_op(relax.op.sinh),
            "softmax.int": self._softmax,
            "softshrink.default": self._softshrink,
            "sqrt.default": self._unary_op(relax.op.sqrt),
            "square.default": self._unary_op(relax.op.square),
            "tan.default": self._unary_op(relax.op.tan),
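The "softshrink.default" key matches the aten overload that torch.export produces for nn.Softshrink / F.softshrink. A quick way to confirm the target name, assuming a PyTorch 2.x install with torch.export (a sketch, not part of the PR):

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.softshrink(x, 0.5)

ep = torch.export.export(M(), (torch.randn(2, 3),))
# Print the call_function targets; the list should contain the
# aten.softshrink.default overload that the convert map keys on.
print([node.target for node in ep.graph.nodes if node.op == "call_function"])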
51 changes: 51 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
@@ -607,6 +607,9 @@ def main(
    # softmax
    test_softmax()

    # softshrink
    test_softshrink()

    # tril, triu
    test_tril_triu()
@@ -741,6 +744,54 @@ def main(
    verify_model(Softmax2(), example_args, {}, expected1)


def test_softshrink():
    class Softshrink(Module):
        def __init__(self):
            super().__init__()
            self.softshrink = torch.nn.Softshrink(lambd=0.5)

        def forward(self, input):
            return self.softshrink(input)

    class Softshrink2(Module):
        def forward(self, input):
            return torch.nn.functional.softshrink(input, lambd=0.5)

    @tvm.script.ir_module
    class expected_softshrink:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
                    input, R.const(0.5, "float32")
                )
                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(
                    input, R.const(0.5, "float32")
                )
                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv1, "float32")
                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv, lv2)

                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(
                    input, R.const(0.5, "float32")
                )
                lv5: R.Tensor((), dtype="float32") = R.negative(R.const(0.5, "float32"))
                lv6: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(input, lv5)
                lv7: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv6, "float32")
                lv8: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv4, lv7)

                lv9: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv3, lv8)

                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv9,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
    verify_model(Softshrink(), example_args, {}, expected_softshrink)
    verify_model(Softshrink2(), example_args, {}, expected_softshrink)


def test_tril_triu():
    example_args = (torch.randn(10, 10, dtype=torch.float32),)

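Outside the test harness, the new mapping can be exercised end to end by exporting a module and converting it with the ExportedProgram frontend. This is a hedged sketch assuming a TVM build that provides tvm.relax.frontend.torch.from_exported_program; the shapes mirror the test above.

import torch
from tvm.relax.frontend.torch import from_exported_program

class Softshrink(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.softshrink(x, 0.5)

ep = torch.export.export(Softshrink(), (torch.randn(1, 3, 10, 10),))
mod = from_exported_program(ep)
# The printed main function should show the subtract/greater/astype/multiply
# pattern emitted by _softshrink.
mod.show()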