added real numpy interface to detect

This commit is contained in:
Daniel Gordon 2018-03-15 18:13:29 -07:00
parent 911663e082
commit fee68e39ce

View File

@ -13,9 +13,21 @@ def sample(probs):
return len(probs)-1 return len(probs)-1
def c_array(ctype, values): def c_array(ctype, values):
arr = (ctype*len(values))() new_values = values.ctypes.data_as(POINTER(ctype))
arr[:] = values return new_values
return arr
def array_to_image(arr):
import numpy as np
# need to return old values to avoid python freeing memory
arr = arr.transpose(2,0,1)
c = arr.shape[0]
h = arr.shape[1]
w = arr.shape[2]
arr = np.ascontiguousarray(arr.flat, dtype=np.float32) / 255.0
data = arr.ctypes.data_as(POINTER(c_float))
im = IMAGE(w,h,c,data)
return im, arr
class BOX(Structure): class BOX(Structure):
_fields_ = [("x", c_float), _fields_ = [("x", c_float),
@ -42,7 +54,6 @@ class METADATA(Structure):
_fields_ = [("classes", c_int), _fields_ = [("classes", c_int),
("names", POINTER(c_char_p))] ("names", POINTER(c_char_p))]
#lib = CDLL("/home/pjreddie/documents/darknet/libdarknet.so", RTLD_GLOBAL) #lib = CDLL("/home/pjreddie/documents/darknet/libdarknet.so", RTLD_GLOBAL)
lib = CDLL("libdarknet.so", RTLD_GLOBAL) lib = CDLL("libdarknet.so", RTLD_GLOBAL)
@ -141,7 +152,27 @@ def detect(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45):
free_image(im) free_image(im)
free_detections(dets, num) free_detections(dets, num)
return res return res
def detect_numpy(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45):
im, arr = array_to_image(image)
num = c_int(0)
pnum = pointer(num)
predict_image(net, im)
dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh, None, 0, pnum)
num = pnum[0]
if (nms): do_nms_obj(dets, num, meta.classes, nms);
res = []
for j in range(num):
for i in range(meta.classes):
if dets[j].prob[i] > 0:
b = dets[j].bbox
res.append((meta.names[i], dets[j].prob[i], (b.x, b.y, b.w, b.h)))
res = sorted(res, key=lambda x: -x[1])
free_detections(dets, num)
return res
if __name__ == "__main__": if __name__ == "__main__":
#net = load_net("cfg/densenet201.cfg", "/home/pjreddie/trained/densenet201.weights", 0) #net = load_net("cfg/densenet201.cfg", "/home/pjreddie/trained/densenet201.weights", 0)
#im = load_image("data/wolf.jpg", 0, 0) #im = load_image("data/wolf.jpg", 0, 0)
@ -150,7 +181,27 @@ if __name__ == "__main__":
#print r[:10] #print r[:10]
net = load_net("cfg/tiny-yolo.cfg", "tiny-yolo.weights", 0) net = load_net("cfg/tiny-yolo.cfg", "tiny-yolo.weights", 0)
meta = load_meta("cfg/coco.data") meta = load_meta("cfg/coco.data")
r = detect(net, meta, "data/dog.jpg") import scipy.misc
import time
'''
t_start = time.time()
for ii in range(100):
r = detect(net, meta, 'data/dog.jpg')
print(time.time() - t_start)
print(r)
image = scipy.misc.imread('data/dog.jpg')
for ii in range(100):
scipy.misc.imsave('/tmp/image.jpg', image)
r = detect(net, meta, '/tmp/image.jpg')
print(time.time() - t_start)
print(r)
'''
image = scipy.misc.imread('data/dog.jpg')
t_start = time.time()
for ii in range(100):
r = detect_numpy(net, meta, image)
print(time.time() - t_start)
print(r) print(r)