diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 3e0bf64e4c1c..2887184b79b8 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -701,6 +701,26 @@ def ones_like(self, inputs, input_types):
 
         return out
 
+    def new_ones(self, inputs, input_types):
+        """Convert aten::new_ones into a tensor of ones of the requested size.
+
+        inputs[1] is the size and inputs[2] the optional dtype; without a dtype
+        the type of the first input (input_types[0]) is used.
+        """
+        size = inputs[1]
+
+        import torch
+
+        if not isinstance(size, (_expr.Expr, list, tuple, torch.Size, np.ndarray)):
+            msg = "Data type %s could not be parsed in ones op" % (type(size))
+            raise AssertionError(msg)
+
+        if inputs[2] is not None:
+            dtype = _convert_dtype_value(inputs[2])
+        else:
+            dtype = input_types[0]
+        return self.full_impl(size, 1, dtype)
+
     def zeros(self, inputs, input_types):
         data = inputs[0]
 
@@ -765,6 +785,34 @@ def full_like(self, inputs, input_types):
 
         return out
 
+    def new_full(self, inputs, input_types):
+        """Convert aten::new_full into a tensor of the requested size filled with a value.
+
+        inputs[1] is the size, inputs[2] the fill value and inputs[3] the optional
+        dtype; without a dtype the type of the first input (input_types[0]) is used.
+        """
+        data = inputs[1]
+        fill_value = inputs[2]
+        import torch
+
+        if not isinstance(data, (_expr.Expr, list, tuple, torch.Size)):
+            msg = "Data type %s could not be parsed in full op" % (type(data))
+            raise AssertionError(msg)
+
+        if inputs[3] is not None:  # dtype given
+            dtype = _convert_dtype_value(inputs[3])
+        else:
+            # if dtype is None, use the dtype of the input tensor
+            dtype = input_types[0]
+
+        return self.full_impl(data, fill_value, dtype)
+
+    def fill_(self, inputs, input_types):
+        """Convert aten::fill_: a tensor shaped like inputs[0], filled with inputs[1]."""
+        data = inputs[0]
+        fill_value = inputs[1]
+        return self.full_impl(self.infer_shape(data), fill_value, input_types[0])
+
     def linspace(self, inputs, input_types):
         start = inputs[0]
         stop = inputs[1]
@@ -1397,6 +1445,12 @@ def reshape(self, inputs, input_types):
             new_shape = tmp_shape
         return _op.transform.reshape(data, new_shape)
 
+    def reshape_as(self, inputs, input_types):
+        """Convert aten::reshape_as: reshape inputs[0] to the inferred shape of inputs[1]."""
+        data = inputs[0]
+        new_shape = self.infer_shape(inputs[1])
+        return _op.transform.reshape(data, new_shape)
+
     def pixel_shuffle(self, inputs, input_types):
         data = inputs[0]
         upscale_factor = inputs[1]
@@ -2336,6 +2390,18 @@ def empty(self, inputs, input_types):
         shape = inputs[0]
         return _op.zeros(shape, _convert_dtype_value(inputs[1]))
 
+    def empty_like(self, inputs, input_types):
+        """Convert aten::empty_like into zeros with the inferred shape of inputs[0].
+
+        inputs[1] is the optional dtype; without it input_types[0] is used.
+        """
+        shape = self.infer_shape(inputs[0])
+        if inputs[1] is not None:
+            dtype = _convert_dtype_value(inputs[1])
+        else:
+            dtype = input_types[0]
+        return _op.zeros(shape, dtype)
+
     def bincount(self, inputs, input_types):
         data = inputs[0]
         weights = inputs[1]
@@ -3055,8 +3121,11 @@ def create_convert_map(self):
             "aten::ones_like": self.ones_like,
             "aten::zeros": self.zeros,
             "aten::zeros_like": self.zeros_like,
+            "aten::new_ones": self.new_ones,
             "aten::full": self.full,
             "aten::full_like": self.full_like,
+            "aten::new_full": self.new_full,
+            "aten::fill_": self.fill_,
             "aten::linspace": self.linspace,
             "aten::reciprocal": self.reciprocal,
             "aten::repeat": self.repeat,
@@ -3121,6 +3190,7 @@ def create_convert_map(self):
             "aten::size": self.size,
             "aten::view": self.view,
             "aten::reshape": self.reshape,
+            "aten::reshape_as": self.reshape_as,
             "aten::clone": self.clone,
             "aten::log_softmax": self.log_softmax,
             "aten::sigmoid": self.sigmoid,
@@ -3239,6 +3309,7 @@ def create_convert_map(self):
             "aten::tensor": self.identity,  # used for example in tensor(1.0)
             "aten::numel": self.numel,
             "aten::empty": self.empty,
+            "aten::empty_like": self.empty_like,
             "aten::bincount": self.bincount,
             "aten::scatter_add": self.scatter_add,
             "aten::__not__": self.logical_not,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index e4cb6354c017..3ad4aec77491 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -199,6 +199,46 @@ def visit(op):
     torch.cuda.empty_cache()
 
 
+def verify_model_with_input(test_func, input_data, input_dict=None, assert_shape_only=False):
+    """Trace test_func, compile it with Relay and compare against the PyTorch result.
+
+    Parameters
+    ----------
+    test_func : callable
+        Function traced with torch.jit.trace on clones of input_data.
+    input_data : list of torch.Tensor
+        Example inputs; the i-th tensor becomes the graph input "input<i>".
+    input_dict : dict, optional
+        Maps graph input names to the values set before running; by default
+        no input is set.
+    assert_shape_only : bool, optional
+        Compare only the output shape, for ops such as torch.empty whose
+        values are uninitialized.
+    """
+    input_dict = {} if input_dict is None else input_dict
+    baseline_outputs = test_func(*input_data)
+    trace = torch.jit.trace(test_func, [inp.clone() for inp in input_data])
+    input_names = ["input{}".format(idx) for idx, _ in enumerate(input_data)]
+    input_shapes = list(zip(input_names, [inp.shape for inp in input_data]))
+    mod, params = relay.frontend.from_pytorch(trace, input_shapes, {})
+    with tvm.transform.PassContext(opt_level=3):
+        for target in ["llvm", "cuda"]:
+            if not tvm.runtime.enabled(target):
+                continue
+            dev = tvm.device(target, 0)
+            lib = relay.build(mod, target=target, params=params)
+            relay_model = graph_executor.GraphModule(lib["default"](dev))
+            for name, value in input_dict.items():
+                relay_model.set_input(name, value)
+            relay_model.run()
+
+            compiled_output = relay_model.get_output(0).numpy()
+            assert_shapes_match(baseline_outputs, compiled_output)
+            if assert_shape_only:
+                continue
+            tvm.testing.assert_allclose(baseline_outputs, compiled_output, rtol=1e-5, atol=1e-5)
+
+
 # Single operator tests
 @tvm.testing.uses_gpu
 def test_forward_pixel_shuffle():
@@ -1275,6 +1315,18 @@ def forward(self, x):
     verify_model(Reshape3(), input_data=torch.randn(2, 3, 4))
 
 
+@tvm.testing.uses_gpu
+def test_forward_reshape_as():
+    """Check that aten::reshape_as matches the PyTorch result."""
+
+    def test_func(input_tensor, other_tensor):
+        return input_tensor.reshape_as(other_tensor)
+
+    input_data = [torch.rand([2, 1, 10, 1, 10]), torch.rand([2, 1, 10, 10])]
+
+    verify_model_with_input(test_func, input_data, {"input0": input_data[0]})
+
+
 @tvm.testing.uses_gpu
 def test_flatten():
     def _test_flatten(start_dim, end_dim):
@@ -2961,6 +3013,18 @@ def forward(self, *args):
     verify_model(OnesLike3().float().eval(), input_data=input_data)
 
 
+@tvm.testing.uses_gpu
+def test_forward_new_ones():
+    """Check that aten::new_ones matches the PyTorch result."""
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    def test_func(input_tensor):
+        return input_tensor.new_ones([3, 10, 10])
+
+    verify_model_with_input(test_func, [torch.rand(input_shape).float()])
+
+
 @tvm.testing.uses_gpu
 def test_forward_zeros():
     torch.set_grad_enabled(False)
@@ -3034,6 +3098,27 @@ def forward(self, *args):
     verify_model(FullLike3().float().eval(), input_data=input_data)
 
 
+@tvm.testing.uses_gpu
+def test_forward_new_full():
+    """Check that aten::new_full matches the PyTorch result."""
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    def test_func(input_tensor):
+        return input_tensor.new_full([2, 3], 1)
+
+    verify_model_with_input(test_func, [torch.rand(input_shape).float()])
+
+
+def test_forward_fill_():
+    """Check that aten::fill_ matches the PyTorch result."""
+
+    def test_func(x):
+        return x.fill_(3)
+
+    verify_model_with_input(test_func, [torch.rand([1, 3, 10, 10]).float()])
+
+
 @tvm.testing.uses_gpu
 def test_forward_linspace():
     torch.set_grad_enabled(False)
@@ -3752,6 +3837,25 @@ def forward(self, data):
     verify_script_model(Numel(), [(3, 5, 8)], targets)
 
 
+def test_empty():
+    """Check the output shape of aten::empty (values are uninitialized in PyTorch)."""
+
+    def test_func():
+        return torch.empty([1, 3, 10, 10])
+
+    verify_model_with_input(test_func, [], assert_shape_only=True)
+
+
+def test_empty_like():
+    """Check the output shape of aten::empty_like (values are uninitialized in PyTorch)."""
+
+    def test_func(data):
+        return torch.empty_like(data)
+
+    input_data = [torch.rand([1, 3, 10, 10]).float()]
+    verify_model_with_input(test_func, input_data, assert_shape_only=True)
+
+
 def test_forward_pretrained_bert_base_uncased():
     ######################################################################
     # This is an example how to run BERT models using TVM