faster nms and stuff

This commit is contained in:
Joseph Redmon
2018-03-15 15:23:14 -07:00
parent 0b64cb4dd3
commit 0f110834f4
10 changed files with 126 additions and 93 deletions

View File

@ -63,7 +63,7 @@ make_image.argtypes = [c_int, c_int, c_int]
make_image.restype = IMAGE
get_network_boxes = lib.get_network_boxes
get_network_boxes.argtypes = [c_void_p, c_int, c_int, c_float, c_float, POINTER(c_int), c_int]
get_network_boxes.argtypes = [c_void_p, c_int, c_int, c_float, c_float, POINTER(c_int), c_int, POINTER(c_int)]
get_network_boxes.restype = POINTER(DETECTION)
make_network_boxes = lib.make_network_boxes
@ -76,10 +76,6 @@ free_detections.argtypes = [POINTER(DETECTION), c_int]
free_ptrs = lib.free_ptrs
free_ptrs.argtypes = [POINTER(c_void_p), c_int]
num_boxes = lib.num_boxes
num_boxes.argtypes = [c_void_p]
num_boxes.restype = c_int
network_predict = lib.network_predict
network_predict.argtypes = [c_void_p, POINTER(c_float)]
@ -128,9 +124,11 @@ def classify(net, meta, im):
def detect(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45):
im = load_image(image, 0, 0)
num = num_boxes(net)
num = c_int(0)
pnum = pointer(num)
predict_image(net, im)
dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh, None, 0)
dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh, None, 0, pnum)
num = pnum[0]
if (nms): do_nms_obj(dets, num, meta.classes, nms);
res = []