diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index d1e202930d6d..f8634f5da70e 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -409,6 +409,32 @@ def call_binary_op(op, lhs, rhs):
 
         return convert
 
+    def _div(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        inp_1 = args[0]
+        inp_2 = args[1]
+
+        # Handle scalar divisors
+        if isinstance(inp_2, (int, float)):
+            inp_2 = relax.const(inp_2)
+
+        # Get rounding_mode from the positional args or the node kwargs
+        rounding_mode = args[2] if len(node.args) > 2 else node.kwargs.get("rounding_mode", None)
+
+        # Perform division based on the rounding mode
+        if rounding_mode is None:
+            # True division (normal float division)
+            return self.block_builder.emit(relax.op.divide(inp_1, inp_2))
+        elif rounding_mode == "floor":
+            # Floor division
+            return self.block_builder.emit(relax.op.floor_divide(inp_1, inp_2))
+        elif rounding_mode == "trunc":
+            # Trunc division: perform true division, then truncate towards zero
+            true_div = self.block_builder.emit(relax.op.divide(inp_1, inp_2))
+            return self.block_builder.emit(relax.op.trunc(true_div))
+        else:
+            raise ValueError(f"Unsupported rounding_mode: {rounding_mode}")
+
     def _fmod(self, node: fx.Node):
         args = self.retrieve_args(node)
         lhs = args[0]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 8b584906c808..42a57273af4e 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -336,6 +336,7 @@ def create_convert_map(
             "tanh.default": self._unary_op(relax.op.tanh),
             "tril.default": self._tril_triu(relax.op.tril),
             "triu.default": self._tril_triu(relax.op.triu),
+            "trunc.default": self._unary_op(relax.op.trunc),
             # binary
             "add.Tensor": self._binary_op(relax.op.add, operator.add),
             "add_.Tensor": self._binary_op(relax.op.add, operator.add),
@@ -344,6 +345,7 @@ def create_convert_map(
             "bitwise_or_.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_),
             "bitwise_or.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_),
             "div.Tensor": self._binary_op(relax.op.divide, operator.truediv),
+            "div.Tensor_mode": self._div,
             "eq.Scalar": self._binary_op(relax.op.equal, operator.eq),
             "eq.Tensor": self._binary_op(relax.op.equal, operator.eq),
             "floor_divide.default": self._binary_op(relax.op.floor_divide, operator.floordiv),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index f1223b6243cf..199e58cb1d9f 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -725,11 +725,13 @@ def create_convert_map(
             "tril": self._tril_triu(relax.op.tril),
             "triu_": self._inplace_tril_triu(relax.op.triu),
             "triu": self._tril_triu(relax.op.triu),
+            "trunc": self._unary_op(relax.op.trunc),
             # binary
             "add": self._binary_op(relax.op.add, operator.add),
             "and_": self._binary_op(relax.op.bitwise_and, operator.and_),
             "bitwise_or_": self._binary_op_inplace(relax.op.bitwise_or, operator.or_),
             "bitwise_or": self._binary_op(relax.op.bitwise_or, operator.or_),
+            "div": self._div,
             "eq": self._binary_op(relax.op.equal, operator.eq),
             "floordiv": self._binary_op(relax.op.floor_divide, operator.floordiv),
             "fmod": self._fmod,
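Reviewer note, not part of the patch: the `_div` decomposition above is sound because PyTorch defines `rounding_mode="trunc"` as true division rounded towards zero and `rounding_mode="floor"` as the floor of the quotient, so lowering them to `relax.op.trunc(relax.op.divide(...))` and `relax.op.floor_divide(...)` preserves semantics even for negative operands, where the two modes disagree. A minimal PyTorch-only sketch of the invariant (values are illustrative):

```python
import torch

a = torch.tensor([7.0, -7.0])
b = torch.tensor([2.0, 2.0])

print(torch.div(a, b))                         # true division: [ 3.5, -3.5]
print(torch.div(a, b, rounding_mode="trunc"))  # towards zero:  [ 3., -3.]
print(torch.div(a, b, rounding_mode="floor"))  # towards -inf:  [ 3., -4.]

# The decompositions _div relies on:
assert torch.equal(torch.div(a, b, rounding_mode="trunc"), torch.trunc(a / b))
assert torch.equal(torch.div(a, b, rounding_mode="floor"), torch.floor_divide(a, b))
```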
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index be5306c9f456..bfc0a997dfc8 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -150,6 +150,7 @@
     square,
     tan,
     tanh,
+    trunc,
 )
diff --git a/python/tvm/relax/op/unary.py b/python/tvm/relax/op/unary.py
index 11b78dbcc7e5..809ae24cad79 100644
--- a/python/tvm/relax/op/unary.py
+++ b/python/tvm/relax/op/unary.py
@@ -511,6 +511,20 @@ def tanh(x: Expr) -> Expr:
     return _ffi_api.tanh(x)  # type: ignore
 
 
+def trunc(x: Expr) -> Expr:
+    """Take trunc of input data (round towards zero, element-wise).
+
+    Parameters
+    ----------
+    x : relax.Expr
+        The input data
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.trunc(x)  # type: ignore
+
+
 @args_converter.auto
 def clip(x: Expr, min: Expr, max: Expr) -> Expr:
     """Clips tensor values to a specified min and max.
diff --git a/python/tvm/relax/transform/legalize_ops/unary.py b/python/tvm/relax/transform/legalize_ops/unary.py
index 33752b9bd3a7..48b5f2f63046 100644
--- a/python/tvm/relax/transform/legalize_ops/unary.py
+++ b/python/tvm/relax/transform/legalize_ops/unary.py
@@ -50,6 +50,7 @@
 register_legalize("relax.sqrt", _call_topi_without_attr(topi.sqrt, "tir_sqrt"))
 register_legalize("relax.tan", _call_topi_without_attr(topi.tan, "tir_tan"))
 register_legalize("relax.tanh", _call_topi_without_attr(topi.tanh, "tir_tanh"))
+register_legalize("relax.trunc", _call_topi_without_attr(topi.trunc, "tir_trunc"))
 
 register_legalize("relax.clip", _call_topi_without_attr(topi.clip, "tir_clip"))
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 5de8bfde72e5..d1e86cc7f456 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -174,6 +174,7 @@
     topk,
     tril,
     triu,
+    trunc,
     unique,
     variance,
     vm,
@@ -870,6 +871,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "to_vdevice",
     "tril",
     "triu",
+    "trunc",
     "tuple",
     "unique",
     "variance",
diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc
index 64e4b00af56e..f95eb721fc70 100644
--- a/src/relax/op/tensor/unary.cc
+++ b/src/relax/op/tensor/unary.cc
@@ -62,6 +62,7 @@
 RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(square, /*require_float_dtype=*/false);
 RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(sqrt, /*require_float_dtype=*/true);
 RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(tan, /*require_float_dtype=*/true);
 RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(tanh, /*require_float_dtype=*/true);
+RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(trunc, /*require_float_dtype=*/false);
 RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(erf, /*require_float_dtype=*/true);
 
 // relax.clip
diff --git a/src/relax/op/tensor/unary.h b/src/relax/op/tensor/unary.h
index dfbad3789752..6984ba6304eb 100644
--- a/src/relax/op/tensor/unary.h
+++ b/src/relax/op/tensor/unary.h
@@ -133,6 +133,9 @@ Expr tan(Expr x);
 /*! \brief Compute element-wise tanh of data. */
 Expr tanh(Expr x);
 
+/*! \brief Take trunc of input data (round towards zero). */
+Expr trunc(Expr x);
+
 /*! \brief Clips tensor values to a specified min and max.
  */
 Expr clip(Expr x, Expr min, Expr max);
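With the op registration, legalization, and TVMScript binding above in place, the new op can be exercised end to end: build a small Relax function with `R.trunc`, legalize it to `topi.trunc`, and run it. A sketch assuming this patch is applied (module and variable names are illustrative):

```python
import numpy as np
import tvm
from tvm import relax
from tvm.script import relax as R


@tvm.script.ir_module
class TruncModule:
    @R.function
    def main(x: R.Tensor((4,), dtype="float32")) -> R.Tensor((4,), dtype="float32"):
        with R.dataflow():
            gv: R.Tensor((4,), dtype="float32") = R.trunc(x)
            R.output(gv)
        return gv


# LegalizeOps lowers relax.trunc to a TIR PrimFunc backed by topi.trunc
mod = relax.transform.LegalizeOps()(TruncModule)
ex = relax.build(mod, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())
x = tvm.nd.array(np.array([1.7, -1.7, 0.5, -2.3], dtype="float32"))
print(vm["main"](x))  # expected: [ 1. -1.  0. -2.]
```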
diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc
index 9cb8f93841fe..3103e6f5b9c3 100644
--- a/src/target/intrin_rule.cc
+++ b/src/target/intrin_rule.cc
@@ -55,6 +55,9 @@
 TVM_REGISTER_OP("tir.tanh")
     .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
 
 TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>("default.FLowerIntrinsic",
                                                      DispatchPureExtern<FloatSuffix>);
 
+TVM_REGISTER_OP("tir.trunc")
+    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
+
 TVM_REGISTER_OP("tir.atan")
     .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index b07070ddc99f..ab3826b935c5 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -71,6 +71,7 @@ def verify_model(torch_model, example_args, binding, expected, dynamic_shapes=No
     (torch.square, R.square),
     (torch.tan, R.tan),
     (torch.tanh, R.tanh),
+    (torch.trunc, R.trunc),
 ]
 
 
@@ -1092,6 +1093,70 @@ def main(
     verify_model(IsInModel(), example_args, {}, expected)
 
 
+def test_div_mode():
+    # Case 1: Basic division (no rounding mode)
+    class DivModel(torch.nn.Module):
+        def forward(self, a, b):
+            return torch.div(a, b)
+
+    @tvm.script.ir_module
+    class expected_div:
+        @R.function
+        def main(
+            a: R.Tensor((64, 64), dtype="float32"), b: R.Tensor((64,), dtype="float32")
+        ) -> R.Tuple(R.Tensor((64, 64), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((64, 64), dtype="float32") = R.divide(a, b)
+                gv: R.Tuple(R.Tensor((64, 64), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(64, 64, dtype=torch.float32),
+        torch.randn(64, dtype=torch.float32),
+    )
+    verify_model(DivModel(), example_args, {}, expected_div)
+
+    # Case 2: Division with trunc rounding
+    class DivTruncModel(torch.nn.Module):
+        def forward(self, a, b):
+            return torch.div(a, b, rounding_mode="trunc")
+
+    @tvm.script.ir_module
+    class expected_div_trunc:
+        @R.function
+        def main(
+            a: R.Tensor((64, 64), dtype="float32"), b: R.Tensor((64,), dtype="float32")
+        ) -> R.Tuple(R.Tensor((64, 64), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((64, 64), dtype="float32") = R.divide(a, b)
+                lv1: R.Tensor((64, 64), dtype="float32") = R.trunc(lv)
+                gv: R.Tuple(R.Tensor((64, 64), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    verify_model(DivTruncModel(), example_args, {}, expected_div_trunc)
+
+    # Case 3: Division with floor rounding
+    class DivFloorModel(torch.nn.Module):
+        def forward(self, a, b):
+            return torch.div(a, b, rounding_mode="floor")
+
+    @tvm.script.ir_module
+    class expected_div_floor:
+        @R.function
+        def main(
+            a: R.Tensor((64, 64), dtype="float32"), b: R.Tensor((64,), dtype="float32")
+        ) -> R.Tuple(R.Tensor((64, 64), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((64, 64), dtype="float32") = R.floor_divide(a, b)
+                gv: R.Tuple(R.Tensor((64, 64), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(DivFloorModel(), example_args, {}, expected_div_floor)
+
+
 def test_batchnorm2d():
     class BatchNorm2d(Module):
         def __init__(self):
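Outside the test harness, the ExportedProgram path exercised above looks like this; a hedged sketch (the class name `M` is illustrative) showing how `torch.export` selects the `div.Tensor_mode` overload once a `rounding_mode` is passed:

```python
import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

class M(torch.nn.Module):
    def forward(self, a, b):
        return torch.div(a, b, rounding_mode="trunc")

# With rounding_mode set, export records aten::div.Tensor_mode,
# which the convert map now routes to _div
ep = export(M(), (torch.randn(4, 4), torch.randn(4, 4)))
mod = from_exported_program(ep)
mod.show()  # the printed module should contain R.divide followed by R.trunc
```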
"float32")] + + # Case 1: Basic division (no rounding mode) + class DivModel(torch.nn.Module): + def forward(self, x, y): + return torch.div(x, y) + + @tvm.script.ir_module + class expected_div: + @R.function + def main( + inp_0: R.Tensor((64, 64), dtype="float32"), inp_1: R.Tensor((64, 64), dtype="float32") + ) -> R.Tensor((64, 64), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((64, 64), dtype="float32") = R.divide(inp_0, inp_1) + gv: R.Tensor((64, 64), dtype="float32") = lv + R.output(gv) + return gv + + # Case 2: Division with trunc rounding + class DivTruncModel(torch.nn.Module): + def forward(self, x, y): + return torch.div(x, y, rounding_mode="trunc") + + @tvm.script.ir_module + class expected_div_trunc: + @R.function + def main( + inp_0: R.Tensor((64, 64), dtype="float32"), inp_1: R.Tensor((64, 64), dtype="float32") + ) -> R.Tensor((64, 64), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((64, 64), dtype="float32") = R.divide(inp_0, inp_1) + lv1: R.Tensor((64, 64), dtype="float32") = R.trunc(lv) + gv: R.Tensor((64, 64), dtype="float32") = lv1 + R.output(gv) + return gv + + # Case 3: Division with floor rounding + class DivFloorModel(torch.nn.Module): + def forward(self, x, y): + return torch.div(x, y, rounding_mode="floor") + + @tvm.script.ir_module + class expected_div_floor: + @R.function + def main( + inp_0: R.Tensor((64, 64), dtype="float32"), inp_1: R.Tensor((64, 64), dtype="float32") + ) -> R.Tensor((64, 64), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((64, 64), dtype="float32") = R.floor_divide(inp_0, inp_1) + gv: R.Tensor((64, 64), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(DivModel(), input_info, {}, expected_div) + verify_model(DivTruncModel(), input_info, {}, expected_div_trunc) + verify_model(DivFloorModel(), input_info, {}, expected_div_floor) + + def test_size(): input_info = [([1, 3, 10, 10], "float32")] @@ -2881,6 +2941,25 @@ def main( verify_model(Triu(), input_info, {}, expected_triu) verify_model(InplaceTriu(), input_info, {}, expected_triu) + # trunc + class Trunc(torch.nn.Module): + def forward(self, input): + return torch.trunc(input) + + @tvm.script.ir_module + class expected_trunc: + @R.function + def main( + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.trunc(inp_0) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Trunc(), input_info, {}, expected_trunc) + def test_interpolate(): input_info = [([1, 3, 10, 10], "float32")]