Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 8 additions & 18 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,9 @@ def _impl(inputs, input_types):

end[dim] = min(end[dim], target_end)

strides.append(int(inputs[4]))
strides = [1] * len(end)
strides[dim] = int(inputs[4])

return _op.transform.strided_slice(data,
begin=_expr.const(begin),
end=_expr.const(end),
Expand Down Expand Up @@ -683,7 +685,7 @@ def _impl(inputs, input_types):
data = inputs[0]

pool_size = inputs[1]
strides = inputs[2]
strides = inputs[2] if inputs[2] else pool_size
padding = inputs[3]
dilation = inputs[4]
ceil_mode = int(inputs[5])
Expand All @@ -706,7 +708,7 @@ def _impl(inputs, input_types):
data = inputs[0]

pool_size = inputs[1]
strides = inputs[2]
strides = inputs[2] if inputs[2] else pool_size
padding = inputs[3]
dilation = inputs[4]
ceil_mode = int(inputs[5])
Expand All @@ -723,7 +725,7 @@ def _impl(inputs, input_types):
data = inputs[0]

pool_size = inputs[1]
strides = inputs[2]
strides = inputs[2] if inputs[2] else pool_size
padding = inputs[3]
dilation = inputs[4]
ceil_mode = int(inputs[5])
Expand Down Expand Up @@ -1173,14 +1175,8 @@ def _impl(inputs, input_types):
data = inputs[0]

pool_size = inputs[1]

if inputs[2]:
strides = inputs[2]
else:
strides = pool_size

strides = inputs[2] if inputs[2] else pool_size
padding = inputs[3]

ceil_mode = int(inputs[4])
count_include_pad = int(inputs[5])

Expand All @@ -1204,14 +1200,8 @@ def _impl(inputs, input_types):
data = inputs[0]

pool_size = inputs[1]

if inputs[2]:
strides = inputs[2]
else:
strides = pool_size

strides = inputs[2] if inputs[2] else pool_size
padding = inputs[3]

ceil_mode = int(inputs[4])
count_include_pad = int(inputs[5])

Expand Down
100 changes: 71 additions & 29 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,13 @@ def test_forward_maxpool2d():
stride=2).eval(),
input_data)

# A functional variant (default strides = None case)
class MaxPool2D(Module):
def forward(self, *args):
return torch.nn.functional.max_pool2d(args[0], kernel_size=[10, 10])

verify_model(MaxPool2D(), input_data=input_data)

class MaxPool2DWithIndices(Module):
def __init__(self):
super(MaxPool2DWithIndices, self).__init__()
Expand All @@ -700,6 +707,14 @@ def test_forward_maxpool1d():
stride=2).eval(),
input_data)

# A functional variant (default strides = None case)
class MaxPool1D(Module):
def forward(self, *args):
return torch.nn.functional.max_pool1d(args[0], kernel_size=10)

verify_model(MaxPool1D(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_maxpool3d():
torch.set_grad_enabled(False)
Expand All @@ -715,6 +730,14 @@ def test_forward_maxpool3d():
stride=2).eval(),
input_data)

# A functional variant (default strides = None case)
class MaxPool3D(Module):
def forward(self, *args):
return torch.nn.functional.max_pool3d(args[0], kernel_size=[10, 10, 10])

verify_model(MaxPool3D(), input_data=input_data)


@tvm.testing.uses_gpu
def test_forward_split():
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -1323,10 +1346,20 @@ def forward(self, *args):
x1 = torch.tensor(3) + torch.tensor(1)
return args[0][:, x0:, 1:x1, :]

class SliceWithStride(torch.nn.Module):
def forward(self, x):
return x[..., 0::2] + x[..., 1::2]

class SliceWithStride2(torch.nn.Module):
def forward(self, x):
return x[0::2, 0::2] + x[1::2, 1::2]

input_data = torch.rand(input_shape).float()
verify_model(Slice1().float().eval(), input_data=input_data)
verify_model(Slice2().float().eval(), input_data=input_data)
verify_model(Slice3().float().eval(), input_data=input_data)
verify_model(Slice1(), input_data=input_data)
verify_model(Slice2(), input_data=input_data)
verify_model(Slice3(), input_data=input_data)
verify_model(SliceWithStride(), input_data=torch.randn(1, 4))
verify_model(SliceWithStride2(), input_data=torch.randn(4, 4))


@tvm.testing.uses_gpu
Expand Down Expand Up @@ -1584,9 +1617,11 @@ def _gen_rand_inputs(num_boxes):
scores = torch.rand(num_boxes, dtype=torch.float)
return boxes, scores

targets = ["llvm"] # dynamic nms does not work on gpu

for num_boxes, iou_thres in [(10, 0.3), (100, 0.5), (500, 0.9)]:
in_boxes, in_scores = _gen_rand_inputs(num_boxes)
verify_trace_model(NonMaxSupression(iou_thres), [in_boxes, in_scores])
verify_trace_model(NonMaxSupression(iou_thres), [in_boxes, in_scores], targets)


@tvm.testing.uses_gpu
Expand Down Expand Up @@ -1752,46 +1787,53 @@ def test_3d_models():
verify_model(resnet3d, [torch.rand(input_shape)], atol=1e-4, rtol=1e-4)


def verify_script_model(pt_model, ishapes):
def _get_default_vm_targets():
return [tgt for (tgt, _) in tvm.testing.enabled_targets()]


def verify_script_model(pt_model, ishapes, targets):
script_module = torch.jit.script(pt_model)
verify_model_vm(script_module, ishapes)
verify_model_vm(script_module, ishapes, targets=targets)


def verify_trace_model(pt_model, idata):
def verify_trace_model(pt_model, idata, targets):
traced_model = torch.jit.trace(pt_model, idata)
ishapes = [data.shape for data in idata]
verify_model_vm(traced_model, ishapes, idata=idata)
verify_model_vm(traced_model, ishapes, idata=idata, targets=targets)


def verify_model_vm(imodel, ishapes, idtype=torch.float, idata=None):
input_model = imodel
def verify_model_vm(input_model, ishapes, idtype=torch.float,
idata=None, targets=["llvm"]):
input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
input_shapes = list(zip(input_names, ishapes))
input_data = idata if idata else [torch.randn(shape, dtype=idtype)
for shape in ishapes]
# Compile via VM
mod, params = relay.frontend.from_pytorch(input_model, input_shapes)

executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0),
target="llvm")
evaluator = executor.evaluate()
for tgt in targets:
print("Running on target", tgt)
ctx = tvm.context(tgt, 0)

# Inference
for name, inp in zip(input_names, input_data):
params[name] = inp.numpy()
vm_res = evaluator(**params)
executor = relay.create_executor("vm", mod=mod, ctx=ctx, target=tgt)
evaluator = executor.evaluate()

# Baseline result
with torch.no_grad():
pt_result = input_model(*input_data)
# Inference
for name, inp in zip(input_names, input_data):
params[name] = inp.numpy()
vm_res = evaluator(**params)

# Verify the accuracy
if not isinstance(pt_result, torch.Tensor):
tvm_res = vm_res.asnumpy().item()
assert pt_result == tvm_res
else:
tvm.testing.assert_allclose(vm_res.asnumpy(), pt_result.numpy(),
rtol=1e-5, atol=1e-5)
# Baseline result
with torch.no_grad():
pt_result = input_model(*input_data)

# Verify the accuracy
if not isinstance(pt_result, torch.Tensor):
tvm_res = vm_res.asnumpy().item()
assert pt_result == tvm_res
else:
tvm.testing.assert_allclose(vm_res.asnumpy(), pt_result.numpy(),
rtol=1e-5, atol=1e-5)


@tvm.testing.uses_gpu
Expand Down Expand Up @@ -1905,7 +1947,7 @@ def forward(self, inp):
]

for pt_model in models:
verify_script_model(pt_model.eval(), [(10, 20)])
verify_script_model(pt_model.eval(), [(10, 20)], _get_default_vm_targets())


@tvm.testing.uses_gpu
Expand Down Expand Up @@ -1943,7 +1985,7 @@ def forward(self, xs):
y, h = self.cell(xs[i], h)
return y

verify_script_model(RNNLoop().eval(), [(10, 10, 4)])
verify_script_model(RNNLoop().eval(), [(10, 10, 4)], _get_default_vm_targets())


@tvm.testing.uses_gpu
Expand Down