Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,21 @@ def ones_like(self, inputs, input_types):

return out

def new_ones(self, inputs, input_types):
size = inputs[1]

import torch

if not isinstance(size, (_expr.Expr, list, tuple, torch.Size, np.ndarray)):
msg = "Data type %s could not be parsed in ones op" % (type(size))
raise AssertionError(msg)

if inputs[2] is not None:
dtype = _convert_dtype_value(inputs[2])
else:
dtype = input_types[0]
return self.full_impl(size, 1, dtype)

def zeros(self, inputs, input_types):
data = inputs[0]

Expand Down Expand Up @@ -765,6 +780,28 @@ def full_like(self, inputs, input_types):

return out

def new_full(self, inputs, input_types):
data = inputs[1]
fill_value = inputs[2]
import torch

if not isinstance(data, (_expr.Expr, list, tuple, torch.Size)):
msg = "Data type %s could not be parsed in full op" % (type(data))
raise AssertionError(msg)

if inputs[3] is not None: # dtype given
dtype = _convert_dtype_value(inputs[3])
else:
# if dtype is None, use the dtype of the input tensor
dtype = self.infer_type(input[0])

return self.full_impl(data, fill_value, dtype)

def fill_(self, inputs, input_types):
data = inputs[0]
fill_value = inputs[1]
return self.full_impl(self.infer_shape(data), fill_value, input_types[0])

def linspace(self, inputs, input_types):
start = inputs[0]
stop = inputs[1]
Expand Down Expand Up @@ -1397,6 +1434,11 @@ def reshape(self, inputs, input_types):
new_shape = tmp_shape
return _op.transform.reshape(data, new_shape)

def reshape_as(self, inputs, input_types):
data = inputs[0]
new_shape = self.infer_shape(inputs[1])
return _op.transform.reshape(data, new_shape)

def pixel_shuffle(self, inputs, input_types):
data = inputs[0]
upscale_factor = inputs[1]
Expand Down Expand Up @@ -2336,6 +2378,14 @@ def empty(self, inputs, input_types):
shape = inputs[0]
return _op.zeros(shape, _convert_dtype_value(inputs[1]))

def empty_like(self, inputs, input_types):
shape = self.infer_shape(inputs[0])
if inputs[1] is not None:
dtype = _convert_dtype_value(inputs[1])
else:
dtype = input_types[0]
return _op.zeros(shape, dtype)

def bincount(self, inputs, input_types):
data = inputs[0]
weights = inputs[1]
Expand Down Expand Up @@ -3055,8 +3105,11 @@ def create_convert_map(self):
"aten::ones_like": self.ones_like,
"aten::zeros": self.zeros,
"aten::zeros_like": self.zeros_like,
"aten::new_ones": self.new_ones,
"aten::full": self.full,
"aten::full_like": self.full_like,
"aten::new_full": self.new_full,
"aten::fill_": self.fill_,
"aten::linspace": self.linspace,
"aten::reciprocal": self.reciprocal,
"aten::repeat": self.repeat,
Expand Down Expand Up @@ -3121,6 +3174,7 @@ def create_convert_map(self):
"aten::size": self.size,
"aten::view": self.view,
"aten::reshape": self.reshape,
"aten::reshape_as": self.reshape_as,
"aten::clone": self.clone,
"aten::log_softmax": self.log_softmax,
"aten::sigmoid": self.sigmoid,
Expand Down Expand Up @@ -3239,6 +3293,7 @@ def create_convert_map(self):
"aten::tensor": self.identity, # used for example in tensor(1.0)
"aten::numel": self.numel,
"aten::empty": self.empty,
"aten::empty_like": self.empty_like,
"aten::bincount": self.bincount,
"aten::scatter_add": self.scatter_add,
"aten::__not__": self.logical_not,
Expand Down
75 changes: 75 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,28 @@ def visit(op):
torch.cuda.empty_cache()


def verify_model_with_input(test_func, input_data, input_dict={}):
baseline_outputs = test_func(*input_data)
trace = torch.jit.trace(test_func, [input.clone() for input in input_data])
input_names = ["input{}".format(idx) for idx, inp in enumerate(input_data)]
input_shapes = list(zip(input_names, [inp.shape for inp in input_data]))
mod, params = relay.frontend.from_pytorch(trace, input_shapes, {})
with tvm.transform.PassContext(opt_level=3):
for target in ["llvm", "cuda"]:
if not tvm.runtime.enabled(target):
continue
dev = tvm.device(target, 0)
lib = relay.build(mod, target=target, params=params)
relay_model = graph_executor.GraphModule(lib["default"](dev))
for name, value in input_dict.items():
relay_model.set_input(name, value)
relay_model.run()

compiled_output = relay_model.get_output(0).numpy()
assert_shapes_match(baseline_outputs, compiled_output)
tvm.testing.assert_allclose(baseline_outputs, compiled_output, rtol=1e-5, atol=1e-5)


# Single operator tests
@tvm.testing.uses_gpu
def test_forward_pixel_shuffle():
Expand Down Expand Up @@ -1275,6 +1297,16 @@ def forward(self, x):
verify_model(Reshape3(), input_data=torch.randn(2, 3, 4))


@tvm.testing.uses_gpu
def test_forward_reshape_as():
def test_func(input_tensor, other_tensor):
return input_tensor.reshape_as(other_tensor)

input_data = [torch.rand([2, 1, 10, 1, 10]), torch.rand([2, 1, 10, 10])]

verify_model_with_input(test_func, input_data, {"input0": input_data[0]})


@tvm.testing.uses_gpu
def test_flatten():
def _test_flatten(start_dim, end_dim):
Expand Down Expand Up @@ -2961,6 +2993,17 @@ def forward(self, *args):
verify_model(OnesLike3().float().eval(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_new_ones():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

def test_func(input_tensor):
return input_tensor.new_ones([3, 10, 10])

verify_model_with_input(test_func, [torch.rand(input_shape).float()])


@tvm.testing.uses_gpu
def test_forward_zeros():
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -3034,6 +3077,24 @@ def forward(self, *args):
verify_model(FullLike3().float().eval(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_new_full():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

def test_func(input_tensor):
return input_tensor.new_full([2, 3], 1)

verify_model_with_input(test_func, [torch.rand(input_shape).float()])


def test_forward_fill_():
def test_func(x):
return x.fill_(3)

verify_model_with_input(test_func, [torch.rand([1, 3, 10, 10]).float()])


@tvm.testing.uses_gpu
def test_forward_linspace():
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -3752,6 +3813,20 @@ def forward(self, data):
verify_script_model(Numel(), [(3, 5, 8)], targets)


def test_empty():
def test_func():
return torch.empty([1, 3, 10, 10])

verify_model_with_input(test_func, [])


def test_empty_like():
def test_func(data):
return torch.empty_like(data)

verify_model_with_input(test_func, [torch.rand([1, 3, 10, 10]).float()])


def test_forward_pretrained_bert_base_uncased():
######################################################################
# This is an example how to run BERT models using TVM
Expand Down