26 changes: 26 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -409,6 +409,32 @@ def call_binary_op(op, lhs, rhs):

    return convert

def _div(self, node: fx.Node) -> relax.Var:
    args = self.retrieve_args(node)
    inp_1 = args[0]
    inp_2 = args[1]

    # Promote a scalar divisor to a Relax constant
    if isinstance(inp_2, (int, float)):
        inp_2 = relax.const(inp_2)

    # rounding_mode may be passed positionally or as a keyword argument
    rounding_mode = args[2] if len(node.args) > 2 else node.kwargs.get("rounding_mode", None)

    # Perform division based on the rounding mode
    if rounding_mode is None:
        # True division (normal float division)
        return self.block_builder.emit(relax.op.divide(inp_1, inp_2))
    elif rounding_mode == "floor":
        # Floor division (round towards negative infinity)
        return self.block_builder.emit(relax.op.floor_divide(inp_1, inp_2))
    elif rounding_mode == "trunc":
        # Trunc division: perform true division, then round towards zero
        true_div = self.block_builder.emit(relax.op.divide(inp_1, inp_2))
        return self.block_builder.emit(relax.op.trunc(true_div))
    else:
        raise ValueError(f"Unsupported rounding_mode: {rounding_mode}")
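
For reference (not part of this diff), a minimal PyTorch snippet showing the three behaviors that _div distinguishes; floor and trunc only diverge on quotients that are negative and fractional:

import torch

a = torch.tensor([7.0, -7.0])
b = torch.tensor([2.0, 2.0])
torch.div(a, b)                         # tensor([ 3.5000, -3.5000]), true division
torch.div(a, b, rounding_mode="floor")  # tensor([ 3., -4.]), rounds towards -inf
torch.div(a, b, rounding_mode="trunc")  # tensor([ 3., -3.]), rounds towards zero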

def _fmod(self, node: fx.Node):
    args = self.retrieve_args(node)
    lhs = args[0]
@@ -336,6 +336,7 @@ def create_convert_map(
"tanh.default": self._unary_op(relax.op.tanh),
"tril.default": self._tril_triu(relax.op.tril),
"triu.default": self._tril_triu(relax.op.triu),
"trunc.default": self._unary_op(relax.op.trunc),
# binary
"add.Tensor": self._binary_op(relax.op.add, operator.add),
"add_.Tensor": self._binary_op(relax.op.add, operator.add),
@@ -344,6 +345,7 @@
"bitwise_or_.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_),
"bitwise_or.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_),
"div.Tensor": self._binary_op(relax.op.divide, operator.truediv),
"div.Tensor_mode": self._div,
"eq.Scalar": self._binary_op(relax.op.equal, operator.eq),
"eq.Tensor": self._binary_op(relax.op.equal, operator.eq),
"floor_divide.default": self._binary_op(relax.op.floor_divide, operator.floordiv),
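
A sketch (not part of this diff) of how the div.Tensor_mode entry above is reached end-to-end via torch.export; the model and shapes are illustrative:

import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

class M(torch.nn.Module):
    def forward(self, a, b):
        return torch.div(a, b, rounding_mode="trunc")

# div with a rounding_mode traces to aten::div.Tensor_mode, which routes to _div
ep = export(M(), (torch.randn(4, 4), torch.randn(4, 4)))
mod = from_exported_program(ep)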
2 changes: 2 additions & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -725,11 +725,13 @@ def create_convert_map(
"tril": self._tril_triu(relax.op.tril),
"triu_": self._inplace_tril_triu(relax.op.triu),
"triu": self._tril_triu(relax.op.triu),
"trunc": self._unary_op(relax.op.trunc),
# binary
"add": self._binary_op(relax.op.add, operator.add),
"and_": self._binary_op(relax.op.bitwise_and, operator.and_),
"bitwise_or_": self._binary_op_inplace(relax.op.bitwise_or, operator.or_),
"bitwise_or": self._binary_op(relax.op.bitwise_or, operator.or_),
"div": self._div,
"eq": self._binary_op(relax.op.equal, operator.eq),
"floordiv": self._binary_op(relax.op.floor_divide, operator.floordiv),
"fmod": self._fmod,
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
@@ -150,6 +150,7 @@
square,
tan,
tanh,
trunc,
)


14 changes: 14 additions & 0 deletions python/tvm/relax/op/unary.py
@@ -511,6 +511,20 @@ def tanh(x: Expr) -> Expr:
    return _ffi_api.tanh(x)  # type: ignore


def trunc(x: Expr) -> Expr:
    """Take the element-wise truncation of the input data (round towards zero).

    Parameters
    ----------
    x : relax.Expr
        The input data

    Returns
    -------
    result : relax.Expr
        The computed result.
    """
    return _ffi_api.trunc(x)  # type: ignore


@args_converter.auto
def clip(x: Expr, min: Expr, max: Expr) -> Expr:
    """Clips tensor values to a specified min and max.
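
A quick numerical check of the semantics documented above (using NumPy, whose trunc likewise rounds towards zero):

import numpy as np

np.trunc(np.array([-1.7, -0.5, 0.5, 1.7]))  # array([-1., -0.,  0.,  1.])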
1 change: 1 addition & 0 deletions python/tvm/relax/transform/legalize_ops/unary.py
@@ -50,6 +50,7 @@
register_legalize("relax.sqrt", _call_topi_without_attr(topi.sqrt, "tir_sqrt"))
register_legalize("relax.tan", _call_topi_without_attr(topi.tan, "tir_tan"))
register_legalize("relax.tanh", _call_topi_without_attr(topi.tanh, "tir_tanh"))
register_legalize("relax.trunc", _call_topi_without_attr(topi.trunc, "tir_trunc"))
register_legalize("relax.clip", _call_topi_without_attr(topi.clip, "tir_clip"))


2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
@@ -174,6 +174,7 @@
topk,
tril,
triu,
trunc,
unique,
variance,
vm,
@@ -870,6 +871,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"to_vdevice",
"tril",
"triu",
"trunc",
"tuple",
"unique",
"variance",
1 change: 1 addition & 0 deletions src/relax/op/tensor/unary.cc
@@ -62,6 +62,7 @@ RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(square, /*require_float_dtype=*/false);
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(sqrt, /*require_float_dtype=*/true);
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(tan, /*require_float_dtype=*/true);
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(tanh, /*require_float_dtype=*/true);
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(trunc, /*require_float_dtype=*/false);
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(erf, /*require_float_dtype=*/true);

// relax.clip
3 changes: 3 additions & 0 deletions src/relax/op/tensor/unary.h
@@ -133,6 +133,9 @@ Expr tan(Expr x);
/*! \brief Compute element-wise tanh of data. */
Expr tanh(Expr x);

/*! \brief Take trunc of input data (round towards zero). */
Expr trunc(Expr x);

/*! \brief Clips tensor values to a specified min and max. */
Expr clip(Expr x, Expr min, Expr max);

3 changes: 3 additions & 0 deletions src/target/intrin_rule.cc
@@ -55,6 +55,9 @@ TVM_REGISTER_OP("tir.tanh")
TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>("default.FLowerIntrinsic",
DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tir.trunc")
.set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tir.atan")
.set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

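
For context: DispatchPureExtern<FloatSuffix> resolves a pure intrinsic to the C math-library function with a dtype-dependent suffix, so tir.trunc is expected to lower to truncf for float32 and trunc for float64, matching the neighboring tir.tan rule. A minimal sketch of building the intrinsic from Python, assuming a standard TVM install:

import tvm
from tvm import tir

x = tir.Var("x", "float32")
expr = tir.trunc(x)  # a call to tir.trunc; the rule above lowers it to truncf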
65 changes: 65 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
@@ -71,6 +71,7 @@ def verify_model(torch_model, example_args, binding, expected, dynamic_shapes=No
    (torch.square, R.square),
    (torch.tan, R.tan),
    (torch.tanh, R.tanh),
    (torch.trunc, R.trunc),
]


@@ -1092,6 +1093,70 @@ def main(
    verify_model(IsInModel(), example_args, {}, expected)


def test_div_mode():
    # Case 1: Basic division (no rounding mode)
    class DivModel(torch.nn.Module):
        def forward(self, a, b):
            return torch.div(a, b)

    @tvm.script.ir_module
    class expected_div:
        @R.function
        def main(
            a: R.Tensor((64, 64), dtype="float32"), b: R.Tensor((64,), dtype="float32")
        ) -> R.Tuple(R.Tensor((64, 64), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((64, 64), dtype="float32") = R.divide(a, b)
                gv: R.Tuple(R.Tensor((64, 64), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (
        torch.randn(64, 64, dtype=torch.float32),
        torch.randn(64, dtype=torch.float32),
    )
    verify_model(DivModel(), example_args, {}, expected_div)

    # Case 2: Division with trunc rounding
    class DivTruncModel(torch.nn.Module):
        def forward(self, a, b):
            return torch.div(a, b, rounding_mode="trunc")

    @tvm.script.ir_module
    class expected_div_trunc:
        @R.function
        def main(
            a: R.Tensor((64, 64), dtype="float32"), b: R.Tensor((64,), dtype="float32")
        ) -> R.Tuple(R.Tensor((64, 64), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((64, 64), dtype="float32") = R.divide(a, b)
                lv1: R.Tensor((64, 64), dtype="float32") = R.trunc(lv)
                gv: R.Tuple(R.Tensor((64, 64), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    verify_model(DivTruncModel(), example_args, {}, expected_div_trunc)

    # Case 3: Division with floor rounding
    class DivFloorModel(torch.nn.Module):
        def forward(self, a, b):
            return torch.div(a, b, rounding_mode="floor")

    @tvm.script.ir_module
    class expected_div_floor:
        @R.function
        def main(
            a: R.Tensor((64, 64), dtype="float32"), b: R.Tensor((64,), dtype="float32")
        ) -> R.Tuple(R.Tensor((64, 64), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((64, 64), dtype="float32") = R.floor_divide(a, b)
                gv: R.Tuple(R.Tensor((64, 64), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    verify_model(DivFloorModel(), example_args, {}, expected_div_floor)


def test_batchnorm2d():
    class BatchNorm2d(Module):
        def __init__(self):
79 changes: 79 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
@@ -1934,6 +1934,66 @@ def main(
    verify_model(IsInModel(), input_info, {}, expected)


def test_div_mode():
    input_info = [([64, 64], "float32"), ([64, 64], "float32")]

    # Case 1: Basic division (no rounding mode)
    class DivModel(torch.nn.Module):
        def forward(self, x, y):
            return torch.div(x, y)

    @tvm.script.ir_module
    class expected_div:
        @R.function
        def main(
            inp_0: R.Tensor((64, 64), dtype="float32"), inp_1: R.Tensor((64, 64), dtype="float32")
        ) -> R.Tensor((64, 64), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((64, 64), dtype="float32") = R.divide(inp_0, inp_1)
                gv: R.Tensor((64, 64), dtype="float32") = lv
                R.output(gv)
            return gv

    # Case 2: Division with trunc rounding
    class DivTruncModel(torch.nn.Module):
        def forward(self, x, y):
            return torch.div(x, y, rounding_mode="trunc")

    @tvm.script.ir_module
    class expected_div_trunc:
        @R.function
        def main(
            inp_0: R.Tensor((64, 64), dtype="float32"), inp_1: R.Tensor((64, 64), dtype="float32")
        ) -> R.Tensor((64, 64), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((64, 64), dtype="float32") = R.divide(inp_0, inp_1)
                lv1: R.Tensor((64, 64), dtype="float32") = R.trunc(lv)
                gv: R.Tensor((64, 64), dtype="float32") = lv1
                R.output(gv)
            return gv

    # Case 3: Division with floor rounding
    class DivFloorModel(torch.nn.Module):
        def forward(self, x, y):
            return torch.div(x, y, rounding_mode="floor")

    @tvm.script.ir_module
    class expected_div_floor:
        @R.function
        def main(
            inp_0: R.Tensor((64, 64), dtype="float32"), inp_1: R.Tensor((64, 64), dtype="float32")
        ) -> R.Tensor((64, 64), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((64, 64), dtype="float32") = R.floor_divide(inp_0, inp_1)
                gv: R.Tensor((64, 64), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(DivModel(), input_info, {}, expected_div)
    verify_model(DivTruncModel(), input_info, {}, expected_div_trunc)
    verify_model(DivFloorModel(), input_info, {}, expected_div_floor)


def test_size():
    input_info = [([1, 3, 10, 10], "float32")]

Expand Down Expand Up @@ -2881,6 +2941,25 @@ def main(
    verify_model(Triu(), input_info, {}, expected_triu)
    verify_model(InplaceTriu(), input_info, {}, expected_triu)

    # trunc
    class Trunc(torch.nn.Module):
        def forward(self, input):
            return torch.trunc(input)

    @tvm.script.ir_module
    class expected_trunc:
        @R.function
        def main(
            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.trunc(inp_0)
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(Trunc(), input_info, {}, expected_trunc)


def test_interpolate():
    input_info = [([1, 3, 10, 10], "float32")]