diff --git a/nnvm/python/nnvm/testing/yolo_detection.py b/nnvm/python/nnvm/testing/yolo_detection.py index 9ecb49ae04f0..bdf9efe62de4 100644 --- a/nnvm/python/nnvm/testing/yolo_detection.py +++ b/nnvm/python/nnvm/testing/yolo_detection.py @@ -165,7 +165,7 @@ def do_nms_sort(dets, classes, thresh): if _box_iou(a, b) > thresh: dets[j]['prob'][k] = 0 -def draw_detections(im, dets, thresh, names, classes): +def draw_detections(font_path, im, dets, thresh, names, classes): "Draw the markings around the detected region" for det in dets: labelstr = [] @@ -198,7 +198,7 @@ def draw_detections(im, dets, thresh, names, classes): if bot > imh-1: bot = imh-1 _draw_box_width(im, left, top, right, bot, width, red, green, blue) - label = _get_label(''.join(labelstr), rgb) + label = _get_label(font_path, ''.join(labelstr), rgb) _draw_label(im, top + width, left, label, rgb) def _get_pixel(im, x, y, c): @@ -223,7 +223,7 @@ def _draw_label(im, r, c, label, rgb): val = _get_pixel(label, i, j, k) _set_pixel(im, i+c, j+r, k, val)#rgb[k] * val) -def _get_label(labelstr, rgb): +def _get_label(font_path, labelstr, rgb): from PIL import Image from PIL import ImageDraw from PIL import ImageFont @@ -231,7 +231,7 @@ def _get_label(labelstr, rgb): text = labelstr colorText = "black" testDraw = ImageDraw.Draw(Image.new('RGB', (1, 1))) - font = ImageFont.truetype("arial.ttf", 25) + font = ImageFont.truetype(font_path, 25) width, height = testDraw.textsize(labelstr, font=font) img = Image.new('RGB', (width, height), color=(int(rgb[0]*255), int(rgb[1]*255), int(rgb[2]*255))) diff --git a/nnvm/tutorials/from_darknet.py b/nnvm/tutorials/from_darknet.py index 607af1038628..857ef46015cd 100644 --- a/nnvm/tutorials/from_darknet.py +++ b/nnvm/tutorials/from_darknet.py @@ -153,7 +153,7 @@ # do the detection and bring up the bounding boxes thresh = 0.5 nms_thresh = 0.45 -img = nnvm.testing.darknet.load_image_color(test_image) +img = nnvm.testing.darknet.load_image_color(img_path) _, im_h, im_w = img.shape dets = nnvm.testing.yolo_detection.fill_network_boxes((netw, neth), (im_w, im_h), thresh, 1, tvm_out) @@ -172,6 +172,6 @@ names = [x.strip() for x in content] -nnvm.testing.yolo_detection.draw_detections(img, dets, thresh, names, last_layer.classes) +nnvm.testing.yolo_detection.draw_detections(font_path, img, dets, thresh, names, last_layer.classes) plt.imshow(img.transpose(1, 2, 0)) plt.show()