diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 33f6ffc3132e..2453cccf329b 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -417,6 +417,20 @@ def _rsub(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.subtract(rhs, lhs)) + def _isin(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + elements = args[0] + test_elements = args[1] + + expanded_elements = relax.op.expand_dims(elements, axis=-1) + flattened_test_elements = relax.op.reshape(test_elements, (-1,)) + + comparison = relax.op.equal(expanded_elements, flattened_test_elements) + summed = relax.op.sum(comparison, axis=-1) + result = relax.op.greater(summed, relax.const(0, dtype=elements.struct_info.dtype)) + + return self.block_builder.emit(result) + ########## Neural Network ########## def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 0434712050ed..e489a223388d 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -299,6 +299,7 @@ def create_convert_map( "hardtanh_.default": self._hardtanh, "isfinite.default": self._unary_op(relax.op.isfinite), "isinf.default": self._unary_op(relax.op.isinf), + "isin.Tensor_Tensor": self._isin, "isnan.default": self._unary_op(relax.op.isnan), "leaky_relu.default": self._leakyrelu, "leaky_relu_.default": self._leakyrelu, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 55abf20fcc03..113f12fea734 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -693,6 +693,7 @@ def create_convert_map( "hardtanh": self._hardtanh, "isfinite": self._unary_op(relax.op.isfinite), "isinf": self._unary_op(relax.op.isinf), + "isin": self._isin, "isnan": self._unary_op(relax.op.isnan), "leaky_relu": self._leakyrelu, "log": self._unary_op(relax.op.log), diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 108617991b1f..c040bc2ef3fa 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1060,6 +1060,37 @@ def main( verify_model(RSub2(), example_args2, {}, expected_rsub2) +# IsIn + + +def test_isin(): + class IsInModel(torch.nn.Module): + def forward(self, x, test_elements): + return torch.isin(x, test_elements) + + @tvm.script.ir_module + class expected: + @R.function + def main( + x: R.Tensor((10, 10), dtype="float32"), test_elements: R.Tensor((8,), dtype="float32") + ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")): + with R.dataflow(): + lv: R.Tensor((10, 10, 1), dtype="float32") = R.expand_dims(x, axis=[-1]) + lv1: R.Tensor((8,), dtype="float32") = R.reshape(test_elements, R.shape([8])) + lv2: R.Tensor((10, 10, 8), dtype="bool") = R.equal(lv, lv1) + lv3: R.Tensor((10, 10), dtype="bool") = R.sum(lv2, axis=[-1], keepdims=False) + lv4: R.Tensor((10, 10), dtype="bool") = R.greater(lv3, R.const(0.0, "float32")) + gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv4,) + R.output(gv) + return gv + + example_args = ( + torch.randn(10, 10, dtype=torch.float32), + torch.randn(8, dtype=torch.float32), + ) + verify_model(IsInModel(), example_args, {}, expected) + + def test_batchnorm2d(): class BatchNorm2d(Module): def __init__(self): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index cb69398e0a00..0ce7dc529fb0 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -1868,6 +1868,35 @@ def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="fl verify_model(RSub2(), input_info2, {}, expected_rsub2) +# IsIn + + +def test_isin(): + input_info = [([10, 10], "float32"), ([8], "float32")] + + class IsInModel(torch.nn.Module): + def forward(self, x, test_elements): + return torch.isin(x, test_elements) + + @tvm.script.ir_module + class expected: + @R.function + def main( + inp_0: R.Tensor((10, 10), dtype="float32"), inp_1: R.Tensor((8,), dtype="float32") + ) -> R.Tensor((10, 10), dtype="bool"): + with R.dataflow(): + lv: R.Tensor((10, 10, 1), dtype="float32") = R.expand_dims(inp_0, axis=[-1]) + lv1: R.Tensor((8,), dtype="float32") = R.reshape(inp_1, R.shape([8])) + lv2: R.Tensor((10, 10, 8), dtype="bool") = R.equal(lv, lv1) + lv3: R.Tensor((10, 10), dtype="bool") = R.sum(lv2, axis=[-1], keepdims=False) + lv4: R.Tensor((10, 10), dtype="bool") = R.greater(lv3, R.const(0.0, "float32")) + gv: R.Tensor((10, 10), dtype="bool") = lv4 + R.output(gv) + return gv + + verify_model(IsInModel(), input_info, {}, expected) + + def test_size(): input_info = [([1, 3, 10, 10], "float32")]