Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions nnvm/python/nnvm/testing/yolo_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def do_nms_sort(dets, classes, thresh):
if _box_iou(a, b) > thresh:
dets[j]['prob'][k] = 0

def draw_detections(im, dets, thresh, names, classes):
def draw_detections(font_path, im, dets, thresh, names, classes):
"Draw the markings around the detected region"
for det in dets:
labelstr = []
Expand Down Expand Up @@ -198,7 +198,7 @@ def draw_detections(im, dets, thresh, names, classes):
if bot > imh-1:
bot = imh-1
_draw_box_width(im, left, top, right, bot, width, red, green, blue)
label = _get_label(''.join(labelstr), rgb)
label = _get_label(font_path, ''.join(labelstr), rgb)
_draw_label(im, top + width, left, label, rgb)

def _get_pixel(im, x, y, c):
Expand All @@ -223,15 +223,15 @@ def _draw_label(im, r, c, label, rgb):
val = _get_pixel(label, i, j, k)
_set_pixel(im, i+c, j+r, k, val)#rgb[k] * val)

def _get_label(labelstr, rgb):
def _get_label(font_path, labelstr, rgb):
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont

text = labelstr
colorText = "black"
testDraw = ImageDraw.Draw(Image.new('RGB', (1, 1)))
font = ImageFont.truetype("arial.ttf", 25)
font = ImageFont.truetype(font_path, 25)
width, height = testDraw.textsize(labelstr, font=font)
img = Image.new('RGB', (width, height), color=(int(rgb[0]*255), int(rgb[1]*255),
int(rgb[2]*255)))
Expand Down
4 changes: 2 additions & 2 deletions nnvm/tutorials/from_darknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@
# do the detection and bring up the bounding boxes
thresh = 0.5
nms_thresh = 0.45
img = nnvm.testing.darknet.load_image_color(test_image)
img = nnvm.testing.darknet.load_image_color(img_path)
_, im_h, im_w = img.shape
dets = nnvm.testing.yolo_detection.fill_network_boxes((netw, neth), (im_w, im_h), thresh,
1, tvm_out)
Expand All @@ -172,6 +172,6 @@

names = [x.strip() for x in content]

nnvm.testing.yolo_detection.draw_detections(img, dets, thresh, names, last_layer.classes)
nnvm.testing.yolo_detection.draw_detections(font_path, img, dets, thresh, names, last_layer.classes)
plt.imshow(img.transpose(1, 2, 0))
plt.show()