From 2fe728cc02f979ae5b790e41df84813526e53d72 Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 8 Sep 2020 04:36:33 +0900 Subject: [PATCH 1/3] fix strides conversion --- python/tvm/relay/frontend/pytorch.py | 4 +++- tests/python/frontend/pytorch/test_forward.py | 16 +++++++++++++--- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 99d2dae99e4f..8d326957bc78 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -305,7 +305,9 @@ def _impl(inputs, input_types): end[dim] = min(end[dim], target_end) - strides.append(int(inputs[4])) + strides = [1] * len(end) + strides[dim] = int(inputs[4]) + return _op.transform.strided_slice(data, begin=_expr.const(begin), end=_expr.const(end), diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index b35c7d6ddc2e..69ea7539443e 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1323,10 +1323,20 @@ def forward(self, *args): x1 = torch.tensor(3) + torch.tensor(1) return args[0][:, x0:, 1:x1, :] + class SliceWithStride(torch.nn.Module): + def forward(self, x): + return x[..., 0::2] + x[..., 1::2] + + class SliceWithStride2(torch.nn.Module): + def forward(self, x): + return x[0::2, 0::2] + x[1::2, 1::2] + input_data = torch.rand(input_shape).float() - verify_model(Slice1().float().eval(), input_data=input_data) - verify_model(Slice2().float().eval(), input_data=input_data) - verify_model(Slice3().float().eval(), input_data=input_data) + verify_model(Slice1(), input_data=input_data) + verify_model(Slice2(), input_data=input_data) + verify_model(Slice3(), input_data=input_data) + verify_model(SliceWithStride(), input_data=torch.randn(1, 4)) + verify_model(SliceWithStride2(), input_data=torch.randn(4, 4)) @tvm.testing.uses_gpu From 0b5e0a817349ab551abae5ac38181872ca4bea80 Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 8 Sep 2020 05:05:15 +0900 Subject: [PATCH 2/3] enable gpu target for some vm tests --- tests/python/frontend/pytorch/test_forward.py | 61 +++++++++++-------- 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 69ea7539443e..4215994715d2 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1594,9 +1594,11 @@ def _gen_rand_inputs(num_boxes): scores = torch.rand(num_boxes, dtype=torch.float) return boxes, scores + targets = ["llvm"] # dynamic nms does not work on gpu + for num_boxes, iou_thres in [(10, 0.3), (100, 0.5), (500, 0.9)]: in_boxes, in_scores = _gen_rand_inputs(num_boxes) - verify_trace_model(NonMaxSupression(iou_thres), [in_boxes, in_scores]) + verify_trace_model(NonMaxSupression(iou_thres), [in_boxes, in_scores], targets) @tvm.testing.uses_gpu @@ -1762,19 +1764,23 @@ def test_3d_models(): verify_model(resnet3d, [torch.rand(input_shape)], atol=1e-4, rtol=1e-4) -def verify_script_model(pt_model, ishapes): +def _get_default_vm_targets(): + return [tgt for (tgt, _) in tvm.testing.enabled_targets()] + + +def verify_script_model(pt_model, ishapes, targets): script_module = torch.jit.script(pt_model) - verify_model_vm(script_module, ishapes) + verify_model_vm(script_module, ishapes, targets=targets) -def verify_trace_model(pt_model, idata): +def verify_trace_model(pt_model, idata, targets): traced_model = torch.jit.trace(pt_model, idata) ishapes = [data.shape for data in idata] - verify_model_vm(traced_model, ishapes, idata=idata) + verify_model_vm(traced_model, ishapes, idata=idata, targets=targets) -def verify_model_vm(imodel, ishapes, idtype=torch.float, idata=None): - input_model = imodel +def verify_model_vm(input_model, ishapes, idtype=torch.float, + idata=None, targets=["llvm"]): input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)] input_shapes = list(zip(input_names, ishapes)) input_data = idata if idata else [torch.randn(shape, dtype=idtype) @@ -1782,26 +1788,29 @@ def verify_model_vm(imodel, ishapes, idtype=torch.float, idata=None): # Compile via VM mod, params = relay.frontend.from_pytorch(input_model, input_shapes) - executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), - target="llvm") - evaluator = executor.evaluate() + for tgt in targets: + print("Running on target", tgt) + ctx = tvm.context(tgt, 0) - # Inference - for name, inp in zip(input_names, input_data): - params[name] = inp.numpy() - vm_res = evaluator(**params) + executor = relay.create_executor("vm", mod=mod, ctx=ctx, target=tgt) + evaluator = executor.evaluate() - # Baseline result - with torch.no_grad(): - pt_result = input_model(*input_data) + # Inference + for name, inp in zip(input_names, input_data): + params[name] = inp.numpy() + vm_res = evaluator(**params) - # Verify the accuracy - if not isinstance(pt_result, torch.Tensor): - tvm_res = vm_res.asnumpy().item() - assert pt_result == tvm_res - else: - tvm.testing.assert_allclose(vm_res.asnumpy(), pt_result.numpy(), - rtol=1e-5, atol=1e-5) + # Baseline result + with torch.no_grad(): + pt_result = input_model(*input_data) + + # Verify the accuracy + if not isinstance(pt_result, torch.Tensor): + tvm_res = vm_res.asnumpy().item() + assert pt_result == tvm_res + else: + tvm.testing.assert_allclose(vm_res.asnumpy(), pt_result.numpy(), + rtol=1e-5, atol=1e-5) @tvm.testing.uses_gpu @@ -1915,7 +1924,7 @@ def forward(self, inp): ] for pt_model in models: - verify_script_model(pt_model.eval(), [(10, 20)]) + verify_script_model(pt_model.eval(), [(10, 20)], _get_default_vm_targets()) @tvm.testing.uses_gpu @@ -1953,7 +1962,7 @@ def forward(self, xs): y, h = self.cell(xs[i], h) return y - verify_script_model(RNNLoop().eval(), [(10, 10, 4)]) + verify_script_model(RNNLoop().eval(), [(10, 10, 4)], _get_default_vm_targets()) @tvm.testing.uses_gpu From b77691f311513f23540c3ee3f650e84e485ac27f Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 8 Sep 2020 05:18:44 +0900 Subject: [PATCH 3/3] fix pooling stride None case --- python/tvm/relay/frontend/pytorch.py | 22 ++++-------------- tests/python/frontend/pytorch/test_forward.py | 23 +++++++++++++++++++ 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 8d326957bc78..8d850093f71b 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -685,7 +685,7 @@ def _impl(inputs, input_types): data = inputs[0] pool_size = inputs[1] - strides = inputs[2] + strides = inputs[2] if inputs[2] else pool_size padding = inputs[3] dilation = inputs[4] ceil_mode = int(inputs[5]) @@ -708,7 +708,7 @@ def _impl(inputs, input_types): data = inputs[0] pool_size = inputs[1] - strides = inputs[2] + strides = inputs[2] if inputs[2] else pool_size padding = inputs[3] dilation = inputs[4] ceil_mode = int(inputs[5]) @@ -725,7 +725,7 @@ def _impl(inputs, input_types): data = inputs[0] pool_size = inputs[1] - strides = inputs[2] + strides = inputs[2] if inputs[2] else pool_size padding = inputs[3] dilation = inputs[4] ceil_mode = int(inputs[5]) @@ -1175,14 +1175,8 @@ def _impl(inputs, input_types): data = inputs[0] pool_size = inputs[1] - - if inputs[2]: - strides = inputs[2] - else: - strides = pool_size - + strides = inputs[2] if inputs[2] else pool_size padding = inputs[3] - ceil_mode = int(inputs[4]) count_include_pad = int(inputs[5]) @@ -1206,14 +1200,8 @@ def _impl(inputs, input_types): data = inputs[0] pool_size = inputs[1] - - if inputs[2]: - strides = inputs[2] - else: - strides = pool_size - + strides = inputs[2] if inputs[2] else pool_size padding = inputs[3] - ceil_mode = int(inputs[4]) count_include_pad = int(inputs[5]) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 4215994715d2..e6517003eaac 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -674,6 +674,13 @@ def test_forward_maxpool2d(): stride=2).eval(), input_data) + # A functional variant (default strides = None case) + class MaxPool2D(Module): + def forward(self, *args): + return torch.nn.functional.max_pool2d(args[0], kernel_size=[10, 10]) + + verify_model(MaxPool2D(), input_data=input_data) + class MaxPool2DWithIndices(Module): def __init__(self): super(MaxPool2DWithIndices, self).__init__() @@ -700,6 +707,14 @@ def test_forward_maxpool1d(): stride=2).eval(), input_data) + # A functional variant (default strides = None case) + class MaxPool1D(Module): + def forward(self, *args): + return torch.nn.functional.max_pool1d(args[0], kernel_size=10) + + verify_model(MaxPool1D(), input_data=input_data) + + @tvm.testing.uses_gpu def test_forward_maxpool3d(): torch.set_grad_enabled(False) @@ -715,6 +730,14 @@ def test_forward_maxpool3d(): stride=2).eval(), input_data) + # A functional variant (default strides = None case) + class MaxPool3D(Module): + def forward(self, *args): + return torch.nn.functional.max_pool3d(args[0], kernel_size=[10, 10, 10]) + + verify_model(MaxPool3D(), input_data=input_data) + + @tvm.testing.uses_gpu def test_forward_split(): torch.set_grad_enabled(False)