diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 091884379e00..ebc0132435ba 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2059,9 +2059,21 @@ def scatter_add(self, inputs, input_types): src = inputs[3] return _op.scatter_add(data, index, src, axis=axis) + def is_floating_point(self, inputs, input_types): + assert len(inputs) == 1 + + if isinstance(inputs[0], _expr.Expr): + input_type = self.infer_type(inputs[0]).dtype + else: + input_type = input_types[0] + + is_float = input_type in ["float32", "float64", "float16"] + return _expr.const(is_float) + # Operator mappings def create_convert_map(self): self.convert_map = { + "aten::is_floating_point": self.is_floating_point, "aten::pixel_shuffle": self.pixel_shuffle, "aten::device": self.none, "prim::device": self.none, @@ -2077,6 +2089,7 @@ def create_convert_map(self): "aten::div": self.make_elemwise("divide"), "aten::div_": self.make_elemwise("divide"), "aten::floor_divide": self.make_elemwise("floor_divide"), + "aten::floor_divide_": self.make_elemwise("floor_divide"), "aten::true_divide": self.make_elemwise("divide"), "aten::addcdiv": self.addcdiv, "aten::addcmul": self.addcmul,