From 8e384a03b10502eabd809be8e488be3f6311df18 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Tue, 8 Apr 2025 09:13:16 +0000 Subject: [PATCH 01/11] prelu op support and test script added --- include/tvm/relax/attrs/nn.h | 10 ++++++ .../torch/base_fx_graph_translator.py | 6 ++++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 10 ++++++ python/tvm/relax/op/nn/__init__.py | 1 + python/tvm/relax/op/nn/nn.py | 26 ++++++++++++++ python/tvm/relax/transform/legalize_ops/nn.py | 5 +++ python/tvm/topi/nn/elemwise.py | 7 +++- src/relax/op/nn/nn.cc | 20 +++++++++++ src/relax/op/nn/nn.h | 3 ++ .../test_frontend_from_exported_program.py | 32 +++++++++++++++++ tests/python/relax/test_frontend_from_fx.py | 35 +++++++++++++++++++ tests/python/relax/test_op_nn.py | 4 +++ 13 files changed, 159 insertions(+), 1 deletion(-) diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 0adcf29772cd..163e41d87d68 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -468,6 +468,16 @@ struct SoftplusAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes used in PReLU operator */ +struct PReluAttrs : public tvm::AttrsNode { + int axis; + + TVM_DECLARE_ATTRS(PReluAttrs, "relax.attrs.PReluAttrs") { + TVM_ATTR_FIELD(axis) + .describe("The axis along which the alpha values are applied."); + } +}; + /*! \brief Attributes used in batch_norm operator */ struct BatchNormAttrs : public tvm::AttrsNode { int axis; diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 4c9480b58748..21cbd14d7ea9 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -307,6 +307,12 @@ def _log_softmax(self, node: fx.Node) -> relax.Var: dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) return self.block_builder.emit(relax.op.nn.log_softmax(x, dim)) + def _prelu(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + alpha = self.env[node.args[1]] + axis = 0 if len(x.struct_info.shape) == 1 else 1 + return self.block_builder.emit(relax.op.nn.prelu(x, alpha, axis)) + def _round(self, node: fx.Node) -> relax.Expr: if node.kwargs.get("decimals", 0) != 0: raise ValueError("specifying decimals for round is not supported yet") diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index c82a5e2b1100..2c9e255f2946 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -299,6 +299,7 @@ def create_convert_map( "log1p.default": self._log1p, "log_softmax.int": self._log_softmax, "neg.default": self._unary_op(relax.op.negative), + "prelu.default": self._prelu, "reciprocal.default": self._reciprocal, "relu.default": self._unary_op(relax.op.nn.relu), "relu_.default": self._unary_op(relax.op.nn.relu), diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 297529e8bf29..499e9fc7774c 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -102,6 +102,14 @@ def _log_softmax_module(self, node: fx.Node) -> relax.Var: dim = module.dim assert dim is not None return self.block_builder.emit(relax.op.nn.log_softmax(x, dim)) + + def _prelu_module(self, node: 
fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + alpha_tensor = module.weight.numpy() + alpha = relax.const(alpha_tensor, dtype="float32") + axis = 0 if len(x.struct_info.shape) == 1 else 1 + return self.block_builder.emit(relax.op.nn.prelu(x, alpha, axis)) def _softmax_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] @@ -595,6 +603,7 @@ def create_convert_map( nn.Identity: lambda node: self.env[node.args[0]], nn.LeakyReLU: self._leakyrelu_module, nn.LogSoftmax: self._log_softmax_module, + nn.PReLU: self._prelu_module, nn.ReLU: self._unary_op(relax.op.nn.relu), nn.ReLU6: lambda node: self.block_builder.emit( relax.op.clip(self.env[node.args[0]], 0, 6) @@ -657,6 +666,7 @@ def create_convert_map( "logical_not": self._unary_op(relax.op.logical_not), "log_softmax": self._log_softmax, "neg": self._unary_op(relax.op.negative), + "prelu":self._prelu, "reciprocal": self._reciprocal, "relu": self._unary_op(relax.op.nn.relu), "round": self._round, diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py index 9d56058e4649..14b5dcfc0681 100644 --- a/python/tvm/relax/op/nn/__init__.py +++ b/python/tvm/relax/op/nn/__init__.py @@ -43,6 +43,7 @@ max_pool3d, nll_loss, pad, + prelu, relu, rms_norm, selu, diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 17197b010ef6..eb76e4d9bb14 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -1431,6 +1431,32 @@ def log_softmax(data: Expr, axis: int = -1) -> Expr: return _ffi_api.log_softmax(data, axis) # type: ignore +def prelu(data: Expr, alpha: Expr, axis: int = 1) -> Expr: + r"""Parametric Rectified Linear Unit (PReLU). + + .. math:: + PReLU(x) = x \text{ if } x > 0 \text{ else } \alpha * x + + Parameters + ---------- + data : relax.Expr + The input tensor. + + alpha : relax.Expr + The learnable slope tensor, applied channel-wise. + + axis : int + The axis along which the `alpha` values are applied. + Default is 1 (assuming NCHW format). + + Returns + ------- + result : relax.Expr + The computed result. 
+ """ + return _ffi_api.prelu(data, alpha, axis) + + def batch_norm( data: Expr, gamma: Expr, diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 98fa3ef1ea5e..5d942e5f645d 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -469,6 +469,11 @@ def _nn_leakyrelu(bb: BlockBuilder, call: Call) -> Expr: return bb.call_te(topi.nn.leaky_relu, call.args[0], call.attrs.alpha) +@register_legalize("relax.nn.prelu") +def _nn_prelu(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(topi.nn.prelu, call.args[0], call.args[1], call.attrs.axis) + + @register_legalize("relax.nn.gelu") def _nn_gelu(bb: BlockBuilder, call: Call) -> Expr: def te_gelu(x: te.Tensor): diff --git a/python/tvm/topi/nn/elemwise.py b/python/tvm/topi/nn/elemwise.py index 2b174f8f1ed5..8d23db848d16 100644 --- a/python/tvm/topi/nn/elemwise.py +++ b/python/tvm/topi/nn/elemwise.py @@ -129,10 +129,15 @@ def prelu(x, slope, axis=1): assert len(slope.shape) == 1 assert axis < len(x.shape) + slope = te.compute( + (get_const_int(x.shape[axis]),), + lambda c: slope[0], + name="slope_broadcasted" + ) assert get_const_int(slope.shape[0]) == get_const_int(x.shape[axis]) def _compute_channelwise(*indices): xval = x(*indices) return tvm.tir.Select(xval > 0, xval, xval * slope(indices[axis])) - return te.compute(x.shape, _compute_channelwise) + return te.compute(x.shape, _compute_channelwise) \ No newline at end of file diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 7f545af1301d..797b98b0afc5 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -79,6 +79,26 @@ TVM_REGISTER_OP("relax.nn.softplus") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoUnaryArith) + +/* relax.nn.prelu */ +TVM_REGISTER_NODE_TYPE(PReluAttrs); + +Expr prelu(Expr data, Expr alpha, int axis = 1) { + auto attrs = make_object(); + attrs->axis = axis; + static const Op& op = Op::Get("relax.nn.prelu"); + return Call(op, {data, alpha}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.prelu").set_body_typed(prelu); + +TVM_REGISTER_OP("relax.nn.prelu") + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("alpha", "Tensor", "The channel-wise learnable slope.") + .set_attrs_type() + .set_attr("FInferStructInfo", + InferStructInfoUnaryArith) .set_attr("FPurity", Bool(true)); /* relax.nn.softmax */ diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h index 3f5571af8207..a9c3dd0a5767 100644 --- a/src/relax/op/nn/nn.h +++ b/src/relax/op/nn/nn.h @@ -57,6 +57,9 @@ Expr gelu(Expr data); /*! \brief Gaussian Error Linear Units function approximated by tanh. */ Expr gelu_tanh(Expr data); +/*! \brief Parametric Rectified Linear Unit function.*/ +Expr prelu(Expr data, Expr alpha, int axis); + /*! \brief Scaled Exponential Linear Unit function. 
*/ Expr selu(Expr data); diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 26d3d3f7bde2..fb263b81284a 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -642,6 +642,38 @@ def main( verify_model(LogSoftmax2(), example_args, {}, expected1) +def test_prelu(): + class Prelu1(Module): + def __init__(self, num_parameters=1, alpha=0.25): + super().__init__() + self.prelu = torch.nn.PReLU(num_parameters=num_parameters, init=alpha) + + def forward(self, x): + return self.prelu(x) + + class Prelu2(torch.nn.Module): + def __init__(self): + super(Prelu2, self).__init__() + self.alpha = torch.nn.Parameter(torch.tensor([0.25])) + + def forward(self, x): + return torch.nn.functional.prelu(x, self.alpha) + + @tvm.script.ir_module + class expected: + @R.function + def main(x: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.prelu(x, R.const([0.25], dtype="float32"), axis=1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Prelu1(), example_args, {}, expected) + verify_model(Prelu2(), example_args, {}, expected) + + def test_softmax(): class Softmax(Module): def __init__(self): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index a962de8a3237..dd5923be9e08 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -823,6 +823,38 @@ def main( verify_model(LeakyReLU1(), input_info, {}, expected) +def test_prelu(): + class Prelu1(Module): + def __init__(self, num_parameters=1, alpha=0.25): + super().__init__() + self.prelu = torch.nn.PReLU(num_parameters=num_parameters, init=alpha) + + def forward(self, x): + return self.prelu(x) + + class Prelu2(torch.nn.Module): + def __init__(self): + super(Prelu2, self).__init__() + self.alpha = torch.nn.Parameter(torch.tensor([0.25])) + + def forward(self, x): + return torch.nn.functional.prelu(x, self.alpha) + + @tvm.script.ir_module + class expected: + @R.function + def main(x: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.prelu(x, R.const([0.25], dtype="float32"), axis=1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + input_info = [([1, 3, 10, 10], "float32")] + verify_model(Prelu1(), input_info, {}, expected) + verify_model(Prelu2(), input_info, {}, expected) + + def test_maxpool2d(): input_info = [([1, 3, 10, 10], "float32")] @@ -2266,6 +2298,9 @@ def main( # softplus test_softplus() + # prelu + test_prelu() + # log2 class Log2(Module): def forward(self, x): diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py index 2401153c61de..0baea433fc48 100644 --- a/tests/python/relax/test_op_nn.py +++ b/tests/python/relax/test_op_nn.py @@ -35,6 +35,10 @@ def test_op_correctness(): assert relax.op.nn.dropout(x).op == Op.get("relax.nn.dropout") assert relax.op.nn.pad(x, (1, 1, 1, 1)).op == Op.get("relax.nn.pad") + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + alpha = relax.Var("alpha", R.Tensor((3,), "float32")) + assert relax.op.nn.prelu(x, alpha, axis=1) == 
Op.get("relax.nn.prelu") + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) gamma = relax.Var("gamma", R.Tensor((3,), "float32")) beta = relax.Var("beta", R.Tensor((3,), "float32")) From f462156060f10e3676ef7b70c7100ea4935f1f63 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Tue, 8 Apr 2025 09:14:07 +0000 Subject: [PATCH 02/11] end-of-file issue fixed --- python/tvm/topi/nn/elemwise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/nn/elemwise.py b/python/tvm/topi/nn/elemwise.py index 8d23db848d16..05fa9f7a46f1 100644 --- a/python/tvm/topi/nn/elemwise.py +++ b/python/tvm/topi/nn/elemwise.py @@ -140,4 +140,4 @@ def _compute_channelwise(*indices): xval = x(*indices) return tvm.tir.Select(xval > 0, xval, xval * slope(indices[axis])) - return te.compute(x.shape, _compute_channelwise) \ No newline at end of file + return te.compute(x.shape, _compute_channelwise) From 1093281a6adf18106a6a85fa36d3043e9aa840ee Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Tue, 8 Apr 2025 09:14:43 +0000 Subject: [PATCH 03/11] trailing whitespace issue fixed --- python/tvm/relax/frontend/torch/fx_translator.py | 6 +++--- python/tvm/topi/nn/elemwise.py | 2 +- tests/python/relax/test_frontend_from_exported_program.py | 6 +++--- tests/python/relax/test_frontend_from_fx.py | 6 +++--- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 499e9fc7774c..619a76163ef2 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -102,11 +102,11 @@ def _log_softmax_module(self, node: fx.Node) -> relax.Var: dim = module.dim assert dim is not None return self.block_builder.emit(relax.op.nn.log_softmax(x, dim)) - + def _prelu_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] - module = self.named_modules[node.target] - alpha_tensor = module.weight.numpy() + module = self.named_modules[node.target] + alpha_tensor = module.weight.numpy() alpha = relax.const(alpha_tensor, dtype="float32") axis = 0 if len(x.struct_info.shape) == 1 else 1 return self.block_builder.emit(relax.op.nn.prelu(x, alpha, axis)) diff --git a/python/tvm/topi/nn/elemwise.py b/python/tvm/topi/nn/elemwise.py index 05fa9f7a46f1..56aaf91d07e8 100644 --- a/python/tvm/topi/nn/elemwise.py +++ b/python/tvm/topi/nn/elemwise.py @@ -131,7 +131,7 @@ def prelu(x, slope, axis=1): assert axis < len(x.shape) slope = te.compute( (get_const_int(x.shape[axis]),), - lambda c: slope[0], + lambda c: slope[0], name="slope_broadcasted" ) assert get_const_int(slope.shape[0]) == get_const_int(x.shape[axis]) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index fb263b81284a..fd12340d1671 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -647,7 +647,7 @@ class Prelu1(Module): def __init__(self, num_parameters=1, alpha=0.25): super().__init__() self.prelu = torch.nn.PReLU(num_parameters=num_parameters, init=alpha) - + def forward(self, x): return self.prelu(x) @@ -658,7 +658,7 @@ def __init__(self): def forward(self, x): return torch.nn.functional.prelu(x, self.alpha) - + @tvm.script.ir_module class expected: @R.function @@ -668,7 +668,7 @@ def main(x: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tuple(R.Tensor((1, 3 gv: R.Tuple(R.Tensor((1, 3, 10, 10), 
dtype="float32")) = (lv,) R.output(gv) return gv - + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) verify_model(Prelu1(), example_args, {}, expected) verify_model(Prelu2(), example_args, {}, expected) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index dd5923be9e08..21775aa8a580 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -828,7 +828,7 @@ class Prelu1(Module): def __init__(self, num_parameters=1, alpha=0.25): super().__init__() self.prelu = torch.nn.PReLU(num_parameters=num_parameters, init=alpha) - + def forward(self, x): return self.prelu(x) @@ -839,7 +839,7 @@ def __init__(self): def forward(self, x): return torch.nn.functional.prelu(x, self.alpha) - + @tvm.script.ir_module class expected: @R.function @@ -849,7 +849,7 @@ def main(x: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tensor((1, 3, 10, 10 gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv R.output(gv) return gv - + input_info = [([1, 3, 10, 10], "float32")] verify_model(Prelu1(), input_info, {}, expected) verify_model(Prelu2(), input_info, {}, expected) From 428ff3bd56416338a25cdbe4fb41308e21f286b6 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Tue, 8 Apr 2025 09:51:20 +0000 Subject: [PATCH 04/11] fixing lint issues --- include/tvm/relax/attrs/nn.h | 3 +-- python/tvm/relax/frontend/torch/fx_translator.py | 2 +- python/tvm/topi/nn/elemwise.py | 4 +--- src/relax/op/nn/nn.cc | 2 +- tests/python/relax/test_frontend_from_exported_program.py | 8 ++++++-- tests/python/relax/test_frontend_from_fx.py | 8 ++++++-- 6 files changed, 16 insertions(+), 11 deletions(-) diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 163e41d87d68..e2ce2be6a882 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -473,8 +473,7 @@ struct PReluAttrs : public tvm::AttrsNode { int axis; TVM_DECLARE_ATTRS(PReluAttrs, "relax.attrs.PReluAttrs") { - TVM_ATTR_FIELD(axis) - .describe("The axis along which the alpha values are applied."); + TVM_ATTR_FIELD(axis).describe("The axis along which the alpha values are applied."); } }; diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 619a76163ef2..8cfa72a9b50c 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -666,7 +666,7 @@ def create_convert_map( "logical_not": self._unary_op(relax.op.logical_not), "log_softmax": self._log_softmax, "neg": self._unary_op(relax.op.negative), - "prelu":self._prelu, + "prelu": self._prelu, "reciprocal": self._reciprocal, "relu": self._unary_op(relax.op.nn.relu), "round": self._round, diff --git a/python/tvm/topi/nn/elemwise.py b/python/tvm/topi/nn/elemwise.py index 56aaf91d07e8..59cc3598e9f2 100644 --- a/python/tvm/topi/nn/elemwise.py +++ b/python/tvm/topi/nn/elemwise.py @@ -130,9 +130,7 @@ def prelu(x, slope, axis=1): assert len(slope.shape) == 1 assert axis < len(x.shape) slope = te.compute( - (get_const_int(x.shape[axis]),), - lambda c: slope[0], - name="slope_broadcasted" + (get_const_int(x.shape[axis]),), lambda c: slope[0], name="slope_broadcasted" ) assert get_const_int(slope.shape[0]) == get_const_int(x.shape[axis]) diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 797b98b0afc5..cdc1fbf2db96 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -98,7 +98,7 @@ TVM_REGISTER_OP("relax.nn.prelu") 
.add_argument("alpha", "Tensor", "The channel-wise learnable slope.") .set_attrs_type() .set_attr("FInferStructInfo", - InferStructInfoUnaryArith) + InferStructInfoUnaryArith) .set_attr("FPurity", Bool(true)); /* relax.nn.softmax */ diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index fd12340d1671..e4694efa5617 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -662,9 +662,13 @@ def forward(self, x): @tvm.script.ir_module class expected: @R.function - def main(x: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.prelu(x, R.const([0.25], dtype="float32"), axis=1) + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.prelu( + x, R.const([0.25], dtype="float32"), axis=1 + ) gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) R.output(gv) return gv diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 21775aa8a580..caecce4979b5 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -843,9 +843,13 @@ def forward(self, x): @tvm.script.ir_module class expected: @R.function - def main(x: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.prelu(x, R.const([0.25], dtype="float32"), axis=1) + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.prelu( + x, R.const([0.25], dtype="float32"), axis=1 + ) gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv R.output(gv) return gv From edf511510b41899d804bd4d51e6111d723aae785 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Tue, 8 Apr 2025 11:40:40 +0000 Subject: [PATCH 05/11] fix assertion error in test_op_nn.py file --- tests/python/relax/test_op_nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py index 0baea433fc48..1c03d8fe4649 100644 --- a/tests/python/relax/test_op_nn.py +++ b/tests/python/relax/test_op_nn.py @@ -37,7 +37,7 @@ def test_op_correctness(): x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) alpha = relax.Var("alpha", R.Tensor((3,), "float32")) - assert relax.op.nn.prelu(x, alpha, axis=1) == Op.get("relax.nn.prelu") + assert relax.op.nn.prelu(x, alpha, axis=1).op == Op.get("relax.nn.prelu") x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) gamma = relax.Var("gamma", R.Tensor((3,), "float32")) From bf2ba378a1dd369d5fd2fa99e0d7d2386f5bcb5d Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Wed, 9 Apr 2025 05:08:02 +0000 Subject: [PATCH 06/11] add test script in test_frontend_nn_op.py --- tests/python/relax/test_frontend_nn_op.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index ed81aa49ed34..58fdcd472e2e 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -393,6 +393,7 @@ def test(self, x: Tensor, weight: Tensor, bias: Tensor): 
negative_out = op.negative(x) softplus_out = op.softplus(x, beta=1.0, threshold=20.0) softmax_out = op.softmax(x, axis=2) + prelu_out = op.prelu(x, alpha=bias) rms_norm_out = op.rms_norm(x, weight, axes=[-2, -1]) rms_norm_with_bias_out = op.rms_norm(x, weight, axes=[-2, -1]) group_norm_out = op.group_norm(x, num_groups=1, weight=bias, bias=bias) @@ -411,6 +412,7 @@ def test( silu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.silu(x) gelu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.gelu(x) sigmoid: R.Tensor((2, 3, 4, 5), dtype="float32") = R.sigmoid(x) + prelu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.prelu(x, bias) tanh: R.Tensor((2, 3, 4, 5), dtype="float32") = R.tanh(x) exp: R.Tensor((2, 3, 4, 5), dtype="float32") = R.exp(x) negative: R.Tensor((2, 3, 4, 5), dtype="float32") = R.negative(x) From 5375b49192b9036ded6e0382393f539274b1fe13 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Wed, 9 Apr 2025 05:43:30 +0000 Subject: [PATCH 07/11] include wrapper function for prelu in op.py --- python/tvm/relax/frontend/nn/op.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index e81ff7c5ad2c..86be98cba786 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -1072,6 +1072,34 @@ def softplus(x: Tensor, beta: float = 1.0, threshold: float = 20.0, name: str = return wrap_nested(_op.nn.softplus(x._expr, beta=beta, threshold=threshold), name) +def prelu(x: Tensor, alpha: Tensor, name: str = "prelu"): + r"""Parametric ReLU activation function. + + .. math:: + \text{PReLU}(x) = \begin{cases} + x & \text{if } x \geq 0 \\ + \alpha \cdot x & \text{if } x < 0 + \end{cases} + + Parameters + ---------- + x : Tensor + The input data. + + alpha : Tensor + Slope coefficient for the negative part of the input. + + name : str, optional + Optional name for the operation. Default is "prelu". + + Returns + ------- + result : Tensor + The computed result. + """ + return wrap_nested(_op.nn.prelu(x._expr, alpha._expr), name) + + def tanh(x: Tensor, name: str = "tanh") -> Tensor: r"""Applies the hyperbolic tangent function. 
From 424cf8c207a4348363cef51389a982790e246428 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Wed, 9 Apr 2025 06:40:29 +0000 Subject: [PATCH 08/11] fixing unity check issue by modifying test func --- tests/python/relax/test_frontend_nn_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 58fdcd472e2e..cc09998443b2 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -412,7 +412,6 @@ def test( silu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.silu(x) gelu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.gelu(x) sigmoid: R.Tensor((2, 3, 4, 5), dtype="float32") = R.sigmoid(x) - prelu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.prelu(x, bias) tanh: R.Tensor((2, 3, 4, 5), dtype="float32") = R.tanh(x) exp: R.Tensor((2, 3, 4, 5), dtype="float32") = R.exp(x) negative: R.Tensor((2, 3, 4, 5), dtype="float32") = R.negative(x) @@ -420,6 +419,7 @@ def test( x, beta=1.0, threshold=20.0 ) softmax: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softmax(x, axis=2) + prelu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.prelu(x, bias) rms_norm: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.rms_norm( x, weight, axes=[-2, -1], epsilon=1.0000000000000001e-05 ) From 6ef3b513980dcac38b24ec14468304875bd52c6e Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Wed, 9 Apr 2025 16:28:36 +0000 Subject: [PATCH 09/11] conflicts resolved --- python/tvm/relax/op/nn/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index eb76e4d9bb14..9d9eb3ef4820 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -1446,7 +1446,7 @@ def prelu(data: Expr, alpha: Expr, axis: int = 1) -> Expr: The learnable slope tensor, applied channel-wise. axis : int - The axis along which the `alpha` values are applied. + The axis along which the `alpha` values are applied Default is 1 (assuming NCHW format). 
Returns From e6eea8f3998ee50fbb46c796b6ea669f8d8922ac Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Thu, 10 Apr 2025 09:39:46 +0000 Subject: [PATCH 10/11] add doc for prelu op axis arg --- python/tvm/relax/frontend/torch/fx_translator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 8cfa72a9b50c..0bb414e43569 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -108,7 +108,7 @@ def _prelu_module(self, node: fx.Node) -> relax.Var: module = self.named_modules[node.target] alpha_tensor = module.weight.numpy() alpha = relax.const(alpha_tensor, dtype="float32") - axis = 0 if len(x.struct_info.shape) == 1 else 1 + axis = 0 if len(x.struct_info.shape) == 1 else 1 # Extract Channel size return self.block_builder.emit(relax.op.nn.prelu(x, alpha, axis)) def _softmax_module(self, node: fx.Node) -> relax.Var: From cd8f9cd0770fe55ddae837fa5237994df0921327 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Thu, 10 Apr 2025 10:15:05 +0000 Subject: [PATCH 11/11] fixed failing checks issue --- python/tvm/relax/frontend/torch/fx_translator.py | 2 +- src/relax/op/nn/nn.cc | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 0bb414e43569..a26185ce3caa 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -108,7 +108,7 @@ def _prelu_module(self, node: fx.Node) -> relax.Var: module = self.named_modules[node.target] alpha_tensor = module.weight.numpy() alpha = relax.const(alpha_tensor, dtype="float32") - axis = 0 if len(x.struct_info.shape) == 1 else 1 # Extract Channel size + axis = 0 if len(x.struct_info.shape) == 1 else 1 # Extract Channel size return self.block_builder.emit(relax.op.nn.prelu(x, alpha, axis)) def _softmax_module(self, node: fx.Node) -> relax.Var: diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index cdc1fbf2db96..8c0b86fe5f8e 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -79,6 +79,7 @@ TVM_REGISTER_OP("relax.nn.softplus") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoUnaryArith) + .set_attr("FPurity", Bool(true)); /* relax.nn.prelu */ TVM_REGISTER_NODE_TYPE(PReluAttrs);
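
A minimal usage sketch (illustrative only, not part of the patch series): assuming the patches above are applied, the new relax.nn.prelu op can be constructed directly and lowered through LegalizeOps to the topi.nn.prelu compute wired up in patch 01. The shapes and the 0.25 slope below are example assumptions, not values taken from the PR.

import numpy as np
import tvm
from tvm import relax
from tvm.script import relax as R

# Build a small Relax function that applies prelu channel-wise (axis=1, NCHW).
# Example shapes/values are chosen for illustration.
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor((1, 3, 10, 10), "float32"))
alpha = relax.const(np.array([0.25], dtype="float32"))
with bb.function("main", [x]):
    with bb.dataflow():
        lv = bb.emit(relax.op.nn.prelu(x, alpha, axis=1))
        gv = bb.emit_output(lv)
    bb.emit_func_output(gv)

mod = bb.get()
# Legalize relax.nn.prelu into a TE/TIR implementation via the registered
# topi.nn.prelu lowering, then print the resulting module.
mod = relax.transform.LegalizeOps()(mod)
mod.show()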