diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 3fc202a7cc91..9c68bacf60f4 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -670,6 +670,21 @@ def ones_like(self, inputs, input_types): return out + def new_zeros(self, inputs, input_types): + data = inputs[1] + + import torch + + if not isinstance(data, (_expr.Expr, list, torch.Tensor, np.ndarray)): + msg = "Data type %s could not be parsed in zeros op" % (type(data)) + raise AssertionError(msg) + + if inputs[2] is not None: + dtype = _convert_dtype_value(inputs[2]) + else: + dtype = input_types[0] + return self.full_impl(data, 0, dtype) + def zeros(self, inputs, input_types): data = inputs[0] @@ -1305,7 +1320,10 @@ def tensortonum(self, inputs, input_types): return inputs[0] def view(self, inputs, input_types): - data = inputs[0] + if isinstance(inputs[0], _expr.Expr): + data = inputs[0] + else: + data = _op.cast(_wrap_const(inputs[0]), input_types[0]) if len(inputs) == 3: shape_inp = [inputs[1], self.infer_shape(inputs[2])[0]] @@ -1384,6 +1402,10 @@ def clone(self, inputs, input_types): data = inputs[0] return _op.tensor.copy(data) + def copy(self, inputs, input_types): + src = inputs[1] + return _op.tensor.copy(src) + def log_softmax(self, inputs, input_types): data = inputs[0] axis = int(inputs[1]) @@ -2819,6 +2841,7 @@ def create_convert_map(self): "aten::addcmul": self.addcmul, "aten::ones": self.ones, "aten::ones_like": self.ones_like, + "aten::new_zeros": self.new_zeros, "aten::zeros": self.zeros, "aten::zeros_like": self.zeros_like, "aten::full": self.full, @@ -2878,6 +2901,7 @@ def create_convert_map(self): "aten::view": self.view, "aten::reshape": self.reshape, "aten::clone": self.clone, + "aten::copy_": self.copy, "aten::log_softmax": self.log_softmax, "aten::sigmoid": self.sigmoid, "aten::sigmoid_": self.sigmoid,