-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[Torch] Fix PyTorch NMS conversion for negative scores #7137
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5cf4ee4
cf401cf
5966041
9af6eda
63c43fb
39e007d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ | |
|
|
||
| import tvm | ||
|
|
||
| import tvm.testing | ||
| from tvm import relay | ||
| from tvm.runtime.vm import VirtualMachine | ||
| from tvm.contrib.download import download | ||
|
|
@@ -70,7 +71,7 @@ def generate_jit_model(index): | |
| ] | ||
|
|
||
| model_func = model_funcs[index] | ||
| model = TraceWrapper(model_func(pretrained=True)) | ||
| model = TraceWrapper(model_func(pretrained=True, rpn_pre_nms_top_n_test=200)) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Glad to see
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the default parameter 1000 they picked is fairly conservative. This means for each level in the feature pyramid, of which there is 5 if we use resnet 50 backbone, we get maximum of 1000 x 5 boxes as input to RPN. They have another parameter |
||
|
|
||
| model.eval() | ||
| inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=(1, 3, in_size, in_size))) | ||
|
|
@@ -94,46 +95,40 @@ def test_detection_models(): | |
| download(img_url, img) | ||
|
|
||
| input_shape = (1, 3, in_size, in_size) | ||
| target = "llvm" | ||
|
|
||
| input_name = "input0" | ||
| shape_list = [(input_name, input_shape)] | ||
| score_threshold = 0.9 | ||
|
|
||
| scripted_model = generate_jit_model(1) | ||
| mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) | ||
|
|
||
| with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]): | ||
| vm_exec = relay.vm.compile(mod, target=target, params=params) | ||
|
|
||
| ctx = tvm.cpu() | ||
| vm = VirtualMachine(vm_exec, ctx) | ||
| data = process_image(img) | ||
| pt_res = scripted_model(data) | ||
| data = data.detach().numpy() | ||
| vm.set_input("main", **{input_name: data}) | ||
| tvm_res = vm.run() | ||
|
|
||
| # Note: due to accumulated numerical error, we can't directly compare results | ||
| # with pytorch output. Some boxes might have a quite tiny difference in score | ||
| # and the order can become different. We just measure how many valid boxes | ||
| # there are for input image. | ||
| pt_scores = pt_res[1].detach().numpy().tolist() | ||
| tvm_scores = tvm_res[1].asnumpy().tolist() | ||
| num_pt_valid_scores = num_tvm_valid_scores = 0 | ||
|
|
||
| for score in pt_scores: | ||
| if score >= score_threshold: | ||
| num_pt_valid_scores += 1 | ||
| else: | ||
| break | ||
|
|
||
| for score in tvm_scores: | ||
| if score >= score_threshold: | ||
| num_tvm_valid_scores += 1 | ||
| else: | ||
| break | ||
|
|
||
| assert num_pt_valid_scores == num_tvm_valid_scores, ( | ||
| "Output mismatch: Under score threshold {}, Pytorch has {} valid " | ||
| "boxes while TVM has {}.".format(score_threshold, num_pt_valid_scores, num_tvm_valid_scores) | ||
| ) | ||
| data_np = data.detach().numpy() | ||
|
|
||
| with torch.no_grad(): | ||
| pt_res = scripted_model(data) | ||
|
|
||
| for target in ["llvm", "cuda"]: | ||
| with tvm.transform.PassContext(opt_level=3): | ||
| vm_exec = relay.vm.compile(mod, target=target, params=params) | ||
|
|
||
| ctx = tvm.context(target, 0) | ||
| vm = VirtualMachine(vm_exec, ctx) | ||
|
|
||
| vm.set_input("main", **{input_name: data_np}) | ||
| tvm_res = vm.run() | ||
|
|
||
| # Bounding boxes | ||
| tvm.testing.assert_allclose( | ||
| pt_res[0].cpu().numpy(), tvm_res[0].asnumpy(), rtol=1e-5, atol=1e-5 | ||
| ) | ||
| # Scores | ||
| tvm.testing.assert_allclose( | ||
| pt_res[1].cpu().numpy(), tvm_res[1].asnumpy(), rtol=1e-5, atol=1e-5 | ||
| ) | ||
| # Class ids | ||
| np.testing.assert_equal(pt_res[2].cpu().numpy(), tvm_res[2].asnumpy()) | ||
|
|
||
| score_threshold = 0.9 | ||
| print("Num boxes:", pt_res[0].cpu().numpy().shape[0]) | ||
| print("Num valid boxes:", np.sum(pt_res[1].cpu().numpy() >= score_threshold)) | ||
Uh oh!
There was an error while loading. Please reload this page.