From 0b7f7329a60b9a013096437af17d29a81b1895e7 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Fri, 10 Nov 2023 17:21:35 +0900
Subject: [PATCH] add support for aten::bitwise_and

---
 python/tvm/relay/frontend/pytorch.py          |  9 +++++++
 tests/python/frontend/pytorch/test_forward.py | 27 +++++++++++++++++++
 2 files changed, 36 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 402ab592027c..bdfd8f78b22e 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2327,6 +2327,14 @@ def bitwise_xor(self, inputs, input_types):
 
         return _op.bitwise_xor(lhs, rhs)
 
+    def bitwise_and(self, inputs, input_types):
+        lhs = inputs[0]
+        rhs = inputs[1]
+        lhs = _op.cast(lhs, "bool") if input_types[0] == "bool" else _op.cast(lhs, "int")
+        rhs = _op.cast(rhs, "bool") if input_types[1] == "bool" else _op.cast(rhs, "int")
+
+        return _op.bitwise_and(lhs, rhs)
+
     def logical_not(self, inputs, input_types):
         data = _wrap_const(inputs[0])
         return _op.logical_not(_op.cast(data, "bool"))
@@ -4033,6 +4041,7 @@ def create_convert_map(self):
             "aten::logical_xor": self.logical_xor,
             "aten::bitwise_not": self.bitwise_not,
             "aten::bitwise_xor": self.bitwise_xor,
+            "aten::bitwise_and": self.bitwise_and,
             "aten::Bool": self.Bool,
             "aten::Float": self.Float,
             "aten::rsub": self.rsub,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index b9c1b6ce9cd1..894bea60ed46 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3695,6 +3695,33 @@ def forward(self, *args):
     verify_model(BitwiseXor2().float().eval(), input_data=[lhs])
 
 
+def test_forward_bitwise_and():
+    """test_forward_bitwise_and"""
+    torch.set_grad_enabled(False)
+
+    class BitwiseAnd1(Module):
+        def forward(self, *args):
+            return torch.bitwise_and(args[0], args[1])
+
+    class BitwiseAnd2(Module):
+        def forward(self, *args):
+            rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
+            if torch.cuda.is_available():
+                rhs = rhs.cuda()
+            return torch.bitwise_and(args[0], rhs)
+
+    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
+    rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
+    verify_model(BitwiseAnd1().float().eval(), input_data=[lhs, rhs])
+
+    lhs = torch.tensor([True, True, False])
+    rhs = torch.tensor([False, True, False])
+    verify_model(BitwiseAnd1().float().eval(), input_data=[lhs, rhs])
+
+    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
+    verify_model(BitwiseAnd2().float().eval(), input_data=[lhs])
+
+
 @tvm.testing.uses_gpu
 def test_forward_logical_xor():
     """test_forward_logical_xor"""